    [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791) · f8eda599
    Kristian Holsheimer authored
    * [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes
    
    * [FlaxRoberta] Fix non-broadcastable attention mask
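
    A minimal sketch, not the commit's actual diff, of why a 2-D padding mask stops broadcasting against the per-head attention scores as soon as the batch size exceeds one, and the usual expand-to-4-D fix; all shapes and names below are illustrative.

    ```python
    import jax.numpy as jnp

    batch, heads, seq = 2, 4, 8

    # Attention scores have shape (batch, heads, q_len, k_len).
    scores = jnp.zeros((batch, heads, seq, seq))

    # A raw padding mask from the tokenizer is (batch, k_len). Added to the
    # scores directly it only broadcasts when batch == 1; with batch == 2 the
    # mask's batch dimension lines up against the query-length dimension
    # (2 vs. 8) and broadcasting fails.
    mask = jnp.ones((batch, seq))

    # Expanding the mask to (batch, 1, 1, k_len) lets it broadcast over the
    # head and query dimensions for any batch size.
    bias = jnp.where(mask[:, None, None, :] > 0, 0.0, -1e10)
    masked_scores = scores + bias
    print(masked_scores.shape)  # (2, 4, 8, 8)
    ```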
    
    * Use jax.numpy instead of ordinary numpy (otherwise not jit-able)
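
    A small illustration of the jit constraint mentioned above: ordinary numpy tries to materialize JAX tracers as concrete arrays, which fails inside a `jax.jit`-compiled forward pass, so the mask manipulation has to stay in `jax.numpy`. The helper names here are hypothetical.

    ```python
    import numpy as np
    import jax
    import jax.numpy as jnp

    @jax.jit
    def expand_mask_jnp(mask):
        # jax.numpy ops are traceable, so this works under jit.
        return jnp.expand_dims(mask, axis=(-3, -2))

    @jax.jit
    def expand_mask_np(mask):
        # numpy tries to convert the tracer into a concrete array, which
        # JAX rejects when this function is traced.
        return np.expand_dims(mask, axis=(-3, -2))

    mask = jnp.ones((2, 8))
    print(expand_mask_jnp(mask).shape)  # (2, 1, 1, 8)
    # expand_mask_np(mask)              # would raise an error under jit
    ```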
    
    * Partially revert "Use jax.numpy ..."
    
    * Add tests for batched forward passes
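
    A hedged sketch of what a batched forward-pass check can look like, in the spirit of the test_modeling_flax_roberta.py file listed below; the checkpoint, sentences, and assertion are illustrative and not necessarily those added in this commit. Requires transformers with the Flax/JAX extras installed.

    ```python
    from transformers import BertTokenizerFast, FlaxBertModel

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = FlaxBertModel.from_pretrained("bert-base-uncased")

    # Two sentences of different lengths force padding, i.e. a genuinely
    # batched attention mask containing zeros.
    inputs = tokenizer(
        ["A short sentence.", "A somewhat longer sentence that needs padding."],
        padding=True,
        return_tensors="np",
    )

    # Before the fix, a batch size > 1 hit the broadcasting error sketched above.
    outputs = model(**inputs)
    assert outputs[0].shape[0] == 2  # batch dimension survives the forward pass
    ```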
    
    * Avoid unnecessary OOMs due to preallocation of GPU memory by XLA
    
    * Auto-fix style
    
    * Re-enable GPU memory preallocation but with mem fraction < 1/parallelism
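
    The two memory-related messages above map onto XLA's client-side allocator settings; a sketch of both knobs follows. The earlier message presumably disabled preallocation, the last one re-enables it with a capped fraction. The concrete fraction is an assumption, not a value taken from this commit, and the variables must be set before JAX initializes its GPU backend.

    ```python
    import os

    # What the earlier message suggests: disable preallocation entirely, so a
    # test process no longer grabs most of the GPU up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    # What the final message describes: keep preallocation but cap each process
    # at a fraction of GPU memory below 1 / (number of parallel workers).
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"  # assumed value

    import jax  # noqa: E402  (flags must be set before this import)
    ```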
test_modeling_flax_roberta.py 2.57 KB