[JAX] Flax with compute dtype inferred from input dtype. (#1485)
flax module with compute dtype inferred from the inputs
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment
flax module with compute dtype inferred from the inputs
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>