Unverified Commit 9f9b4816 authored by Phuong Nguyen, committed by GitHub

[JAX] Remove cudaGraph compatible trait from GroupedGemmFFI and GroupedQuantizeFFI (#2048)



* rm cudaGraph compatible trait from GroupedGEMM and groupedQuantize
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add grouped_gemm jitting in the unit test
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent cae1c436
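The test-side change below runs tex.grouped_gemm under jax.jit now that the handler no longer claims CUDA-graph compatibility. A minimal sketch of that jitting pattern, using a stand-in for tex.grouped_gemm since the real primitive needs a GPU build of Transformer Engine; the shapes, group sizes, and stand-in body are illustrative assumptions:

import jax
import jax.numpy as jnp

def grouped_gemm_like(lhs, rhs, group_sizes, contracting_dims):
    # Stand-in with the same argument layout as tex.grouped_gemm; it ignores
    # the grouping and performs a single contraction for illustration.
    (lhs_cdims, rhs_cdims) = contracting_dims
    return jax.lax.dot_general(lhs, rhs, ((lhs_cdims, rhs_cdims), ((), ())))

lhs = jnp.ones((32, 16))
rhs = jnp.ones((16, 8))
group_sizes = jnp.array([8, 8, 16], dtype=jnp.int32)

# contracting_dims is a nested Python tuple that determines the traced
# computation, so the updated tests mark it static via static_argnames.
out = jax.jit(grouped_gemm_like, static_argnames=("contracting_dims",))(
    lhs, rhs, group_sizes, contracting_dims=((1,), (0,))
)

Because the argument is static, each distinct contracting_dims value triggers a fresh compilation, which is the trade-off the tests accept in exchange for exercising the primitive under jit.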
@@ -673,10 +673,6 @@ class TestGroupedQuantize:
             n_groups=n_groups,
         )
-        # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
-        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
-        # disable cudaGraph, then use the following jitted function
         scaled_tensor = tex.grouped_quantize(
             x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
         )
@@ -1312,16 +1308,14 @@ class TestGroupedDense:
         )
         ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
-        # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
-        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
-        # disable cudaGraph, then use the following jitted function
         # jitting grouped_gemm
-        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
-        #     lhs, rhs, group_sizes, contracting_dims,
-        # )
-        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
+        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
+            lhs,
+            rhs,
+            group_sizes,
+            contracting_dims,
+        )
         self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@@ -1350,12 +1344,7 @@ class TestGroupedDense:
         )
         ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
-        # jitting grouped_gemm
-        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
-        #     lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
-        # )
-        prim_out = tex.grouped_gemm(
+        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
             lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
         )
@@ -1391,9 +1380,9 @@ class TestGroupedDense:
         value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
         # jitting the grouped_dense
-        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
-        #                              static_argnums=(4,))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_prim_func = jit(
+            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
+        )
         ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
             x, kernel, bias, group_sizes, contracting_dims
@@ -1432,9 +1421,9 @@ class TestGroupedDense:
         value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
         # jitting the grouped_dense
-        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
-        #                              static_argnums=(4,))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_prim_func = jit(
+            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
+        )
         ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
             x,
...
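These last two hunks wrap the primitive in value_and_grad before jitting; contracting_dims is the fifth positional argument, hence static_argnums=(4,). A sketch of the pattern with a hypothetical stand-in for self._primitive_sum_grouped_dense (a plain matmul-plus-bias reduced to a scalar; the real helper calls the grouped primitives):

import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

def sum_grouped_dense_like(x, kernel, bias, group_sizes, contracting_dims):
    # Scalar loss so value_and_grad yields gradients w.r.t. x, kernel, bias.
    (x_cdims, k_cdims) = contracting_dims
    out = jax.lax.dot_general(x, kernel, ((x_cdims, k_cdims), ((), ()))) + bias
    return jnp.sum(out)

x = jnp.ones((32, 16))
kernel = jnp.ones((16, 8))
bias = jnp.zeros((8,))
group_sizes = jnp.array([16, 16], dtype=jnp.int32)

# Differentiate w.r.t. arguments 0, 1, 2; hold the tuple-valued
# contracting_dims (argument index 4) static so jit can trace it.
value_n_grad = jit(value_and_grad(sum_grouped_dense_like, (0, 1, 2)), static_argnums=(4,))
out_sum, (dgrad, wgrad, dbias) = value_n_grad(x, kernel, bias, group_sizes, ((1,), (0,)))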
@@ -592,8 +592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                                   .Attr<bool>("rhs_is_trans")
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                   .Attr<bool>("has_bias")
-                                  .Attr<bool>("is_grouped_dense_wgrad"),
-                              FFI_CudaGraph_Traits);
+                                  .Attr<bool>("is_grouped_dense_wgrad"));

 }  // namespace jax
 }  // namespace transformer_engine
@@ -410,8 +410,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
                                   .Ret<Buffer_Type>()  // amax
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                   .Attr<int64_t>("q_layout")
-                                  .Attr<int64_t>("flatten_axis"),
-                              FFI_CudaGraph_Traits);
+                                  .Attr<int64_t>("flatten_axis"));

 }  // namespace jax
 }  // namespace transformer_engine
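For context: registering a handler with FFI_CudaGraph_Traits advertises it as safe for XLA command-buffer (cudaGraph) capture, so dropping the trait keeps GroupedGemmFFI and GroupedQuantizeFFI out of graph capture and lets the jitted test calls run without the workaround the deleted comments described. That old workaround, adapted here from the shell export quoted in the removed comments into a Python form, disabled command buffers process-wide:

import os

# Disable XLA command buffers (cudaGraph) before JAX initializes; this is the
# XLA_FLAGS workaround quoted in the removed test comments, no longer needed
# once the trait is gone from the handlers.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer= " + os.environ.get("XLA_FLAGS", "")

import jax  # import after setting XLA_FLAGS so the flag takes effect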