[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)
* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes
* [FlaxRoberta] Fix non-broadcastable attention mask
* Use jax.numpy instead of ordinary numpy (otherwise not jit-able)
* Partially revert "Use jax.numpy ..."
* Add tests for batched forward passes
* Avoid unnecessary OOMs due to preallocation of GPU memory by XLA
* Auto-fix style
* Re-enable GPU memory preallocation but with mem fraction < 1/parallelism
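For context, the broadcasting problem named in the title can be illustrated with a minimal sketch. This is not the code from this commit: the function name `apply_attention_mask`, the tensor layout `(batch, heads, q_len, kv_len)`, and the `-1e9` fill value are assumptions chosen for illustration. A padding mask of shape `(batch, kv_len)` does not broadcast against batched attention scores in general, so it is first expanded with singleton head and query axes.

```python
import jax.numpy as jnp


def apply_attention_mask(scores, attention_mask):
    """Mask padded key positions in batched attention scores.

    Hypothetical helper, not taken from the commit.
    scores:         (batch, heads, q_len, kv_len)
    attention_mask: (batch, kv_len), 1 = real token, 0 = padding
    """
    # Insert singleton head and query axes so the mask has shape
    # (batch, 1, 1, kv_len) and broadcasts against the scores.
    mask = attention_mask[:, None, None, :]
    # Replace scores at padded positions with a large negative value
    # (assumed fill value) so softmax gives them ~zero weight.
    return jnp.where(mask > 0, scores, -1e9)
```

Without the two added axes, a `(batch, kv_len)` mask is aligned against the trailing `(q_len, kv_len)` axes, which fails to broadcast whenever `batch` differs from `q_len` and is greater than 1.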