    XLA train step fixes (#17973) · d6cec458
    Matt authored
    * Copy inputs to train and test step before modifying them, as modifying them in place breaks things
    
    * Add XLA tests, fix our loss functions to be XLA-compatible
    
    * make fixup
    
    * Update loss computation test to expect vector of per-sample losses
    
    * Patch loss for TFLED
    
    * Patch loss for TFAlbert
    
    * Add a tf_legacy_loss config flag that enables old loss functions
    
    * Stop using config.get() because it's not a dict
    
    * Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it
    
    * make fixup
    
    * Add XLA-compatible RAG loss
    
    * Fix dtype of loss mask for TFAlbert
    
    * Fix test for XLNet too because it overrides the default one
    
    * make fixup
    
    * Fix config test
    
    * No more depending on GPU NaN behaviour
    
    * Add test, avoid potential zero division
    
    * Fix test item assignment
    
    * Fix loss computation masking test
    
    * make fixup
    
    * Fix dtype bugs
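
Several of the items above (per-sample loss vectors, no reliance on NaN behaviour, avoiding zero division) describe one pattern: mask the loss by multiplication instead of filtering, then divide by a clamped count. The sketch below illustrates only that arithmetic in plain Python. It is not the TensorFlow code from this commit. The function name `masked_per_sample_loss` is hypothetical, and the ignore label of `-100` is an assumption made for illustration.

```python
# Illustrative only: plain-Python arithmetic, not the TensorFlow code from the commit.
# The function name and the ignore_index default are assumptions.

def masked_per_sample_loss(token_losses, labels, ignore_index=-100):
    """Return one mean loss per sample, skipping positions labelled ignore_index.

    Masked positions are multiplied by 0 rather than filtered out, so the
    amount of work never depends on the data (the property a compiled tensor
    version would need). The denominator is clamped to 1 so a fully masked
    sample yields 0.0 instead of a division by zero.
    """
    per_sample = []
    for losses_row, labels_row in zip(token_losses, labels):
        # 1.0 where the label counts, 0.0 where it is ignored.
        mask = [0.0 if label == ignore_index else 1.0 for label in labels_row]
        total = sum(loss * m for loss, m in zip(losses_row, mask))
        count = max(sum(mask), 1.0)  # guard against an all-masked row
        per_sample.append(total / count)
    return per_sample


token_losses = [[0.5, 1.5, 2.0], [3.0, 4.0, 5.0]]
labels = [[7, 2, -100], [-100, -100, -100]]  # second sample is fully masked
result = masked_per_sample_loss(token_losses, labels)
print(result)
```

The first sample averages its two unmasked positions. The second has no unmasked positions, so the clamp returns a finite zero where an unguarded division would fail or produce NaN.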
test_modeling_tf_common.py 100 KB