Unverified Commit 5b3d65cc authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)



Fix GroupedScaledTensor creation
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent c47f329b
...@@ -442,7 +442,7 @@ def _grouped_dense_fwd_rule( ...@@ -442,7 +442,7 @@ def _grouped_dense_fwd_rule(
ctx_kernel = ScaledTensorFactory.create_1x( ctx_kernel = ScaledTensorFactory.create_1x(
global_ctx_kernel_data.reshape(-1), global_ctx_kernel_data.reshape(-1),
ctx_kernel.scale_inv, ctx_kernel.scale_inv,
ctx_kernel.scaling_mode, scaling_mode=ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype, dq_dtype=ctx_kernel.dq_dtype,
is_colwise=False, is_colwise=False,
data_layout="N", data_layout="N",
...@@ -459,7 +459,7 @@ def _grouped_dense_fwd_rule( ...@@ -459,7 +459,7 @@ def _grouped_dense_fwd_rule(
grouped_gemm_kernel = ScaledTensorFactory.create_1x( grouped_gemm_kernel = ScaledTensorFactory.create_1x(
grouped_gemm_kernel_data.reshape(-1), grouped_gemm_kernel_data.reshape(-1),
ctx_kernel.scale_inv, ctx_kernel.scale_inv,
ctx_kernel.scaling_mode, scaling_mode=ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype, dq_dtype=ctx_kernel.dq_dtype,
is_colwise=True, is_colwise=True,
data_layout="T", data_layout="T",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment