Unverified Commit 5b3d65cc authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)



Fix GroupedScaledTensor creation
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent c47f329b
......@@ -442,7 +442,7 @@ def _grouped_dense_fwd_rule(
ctx_kernel = ScaledTensorFactory.create_1x(
global_ctx_kernel_data.reshape(-1),
ctx_kernel.scale_inv,
ctx_kernel.scaling_mode,
scaling_mode=ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype,
is_colwise=False,
data_layout="N",
......@@ -459,7 +459,7 @@ def _grouped_dense_fwd_rule(
grouped_gemm_kernel = ScaledTensorFactory.create_1x(
grouped_gemm_kernel_data.reshape(-1),
ctx_kernel.scale_inv,
ctx_kernel.scaling_mode,
scaling_mode=ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype,
is_colwise=True,
data_layout="T",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment