[JAX] Fixes for CI failures with the latest JAX (#1469)
* fixes L1 test
* fix test_multigpu_encoder
* fixes for other multi-encoder tests
* jax.extend.ffi to jax.ffi
* initialization with float32
* add init_dtype as an optional arg to all modules
* update use_scan query from xla flags
* relax threshold for test_encoder fp8
* relax the tols
---------
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment