XLA train step fixes (#17973)
* Copy inputs to train and test step before modifying them, as this breaks things * Add XLA tests, fix our loss functions to be XLA-compatible * make fixup * Update loss computation test to expect vector of per-sample losses * Patch loss for TFLED * Patch loss for TFAlbert * Add a tf_legacy_loss config flag that enables old loss functions * Stop using config.get() because it's not a dict * Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it * make fixup * Add XLA-compatible RAG loss * Fix dtype of loss mask for TFAlbert * Fix test for XLNet too because it overrides the default one * make fixup * Fix config test * No more depending on GPU NaN behaviour * Add test, avoid potential zero division * Fix test item assignment * Fix loss computation masking test * make fixup * Fix dtype bugs
Showing
Please register or sign in to comment