"docs/vscode:/vscode.git/clone" did not exist on "c3372e87bed990510e4ae0b39f151a34dea24f8b"
Unverified Commit 2d4644b7 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Set `precision=HIGHEST` for the ref_grouped_gemm impl in the unit test (#1967)



* set precision=HIGHEST for the ref_grouped_gemm impl in the unit test
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 657c965b
......@@ -1265,7 +1265,9 @@ class TestGroupedDense:
ref_out = []
dim_num = (contracting_dims, ((), ()))
for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
out_i = jax.lax.dot_general(
lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
) + jnp.expand_dims(bias_i, axis=0)
ref_out.append(jnp.squeeze(out_i))
return ref_out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment