Unverified Commit 3c4dfffb authored by Hua Huang, committed by GitHub

[JAX] Fix grouped GEMM error on CUDA 12.9.1 & later (#1925)



* Fix JAX grouped GEMM error on CUDA 12.9.1 & later by using 16B alignment for scale ptr
Signed-off-by: Hua Huang <huah@nvidia.com>

* Pad MXFP8 scales with 2**-127 instead of NaNs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
parent 637faccb
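The changes below revolve around two alignment requirements: the cuBLAS workspace pointer must be 256-byte aligned, and on CUDA 12.9.1 and later each per-tensor scale_inv pointer passed to the grouped GEMM must be 16-byte aligned. As a rough illustration only (not code from this commit; round_up is a hypothetical helper), rounding a byte offset up to an alignment boundary looks like:

def round_up(offset: int, alignment: int) -> int:
    # Round offset up to the next multiple of alignment (assumed to be a power of two).
    return (offset + alignment - 1) // alignment * alignment

assert round_up(1000, 16) == 1008    # per-tensor scale_inv pointers: 16 B on CUDA 12.9.1+
assert round_up(1000, 256) == 1024   # cuBLAS workspace pointer: 256 B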
@@ -1343,9 +1343,10 @@ class TestGroupedDense:
     def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
         out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
         # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-        # and prevent them from being clamp to zero
+        # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
+        # normalize the output and prevent the gradient from being too large for FP8.
         out_sum_list = [jnp.sum(out) for out in out_list]
-        return jnp.sum(jnp.asarray(out_sum_list))
+        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

     def _primitive_sum_grouped_dense(
         self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
@@ -1353,7 +1354,7 @@ class TestGroupedDense:
         out = grouped_dense(
             x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
         )
-        return jnp.sum(jnp.asarray(out))
+        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
     def test_grouped_dense_grad_fp16(self, dtype, input_shape):
...
@@ -98,15 +98,27 @@ class GroupedGemmPrimitive(BasePrimitive):
             A jnp.ndarray containing the result of the grouped GEMM operation
         """
         del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
-        del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
+        del K, lhs_is_trans, rhs_is_trans, has_bias
         # TODO(Phuong): move some shape checks from Cpp to here
         workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
+        workspace_alignment_padding = 256
+        tensor_scaling_sinv_aligment = 16
+        mxfp8_scaling_sinv_alignment_padding = 256
         # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
         # necessarily 256 bytes aligned, we add some padding to ensure alignment.
-        # We also pad scale_inv swizzle buffers size for 256 bytes alignment.
-        workspace_size += 256
-        workspace_size += lhs_scale_inv_aval.size + 256
-        workspace_size += rhs_scale_inv_aval.size + 256
+        workspace_size += workspace_alignment_padding
+        if scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING.value,
+            ScalingMode.CURRENT_TENSOR_SCALING.value,
+        ):
+            # For tensor scaling, each matrix has a single scale value, but it
+            # needs to be aligned to 16 bytes for CUDA 12.9.1 and later.
+            workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
+            workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
+        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
+            # We also pad scale_inv swizzle buffers size for 256 bytes alignment.
+            workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
+            workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
         workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)

         out_shape = (M, N)
...
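A condensed sketch of the workspace accounting introduced above, as a standalone function (hypothetical names and plain-string modes; the real primitive works with ScalingMode enum values and TE's cuBLAS stream count):

def grouped_gemm_workspace_bytes(per_stream_bytes, num_streams,
                                 lhs_sinv_size, rhs_sinv_size, scaling_mode):
    # Base cuBLAS workspace for every stream, plus 256 B so the pointer can be
    # shifted to a 256 B boundary inside the JAX-allocated buffer.
    size = per_stream_bytes * num_streams + 256
    if scaling_mode in ("delayed_tensor_scaling", "current_tensor_scaling"):
        # One scale per matrix, but each scale gets its own 16 B slot so the
        # pointer handed to cuBLAS is 16 B aligned (CUDA 12.9.1+ requirement).
        size += (lhs_sinv_size + rhs_sinv_size) * 16
    elif scaling_mode == "mxfp8_1d_scaling":
        # Room for the swizzled scale_inv copies, each padded for 256 B alignment.
        size += lhs_sinv_size + 256
        size += rhs_sinv_size + 256
    return size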
@@ -62,6 +62,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   NVTE_CHECK(group_sizes.dimensions().size() == 1);
   size_t num_gemms = group_sizes.dimensions()[0];

+  // It is weird that TE/Common GEMM only use colwise for MXFP8
+  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
+  const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
+  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
+  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
+  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
+
   // Outputs
   auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
   auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
@@ -72,12 +80,25 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   auto lhs_sinv_size = product(lhs_sinv.dimensions());
   auto rhs_sinv_size = product(rhs_sinv.dimensions());

-  auto workspace_size =
-      (workspace_total_size - lhs_sinv_size - rhs_sinv_size - 3 * 256) / num_streams;
+  const size_t workspace_alignment_padding = 256;
+  const size_t tensor_scaling_sinv_aligment = 16;
+  const size_t mxfp8_scaling_sinv_alignment_padding = 256;
+  auto workspace_size = workspace_total_size - workspace_alignment_padding;
+  if (is_mxfp8_scaling) {
+    // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment
+    // padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded
+    // by 128x4.
+    workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding);
+  } else if (is_tensor_scaling) {
+    // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned
+    // by 16 bytes to meet the requirement of CUDA 12.9.1 and later.
+    workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size);
+  }
+  workspace_size = workspace_size / num_streams;
   auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
   swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
   auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
   swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
+  auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr;  // Already 256B aligned
+  auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment;

   size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
   size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
@@ -86,6 +107,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
   size_t out_dtype_bytes = te_dtype_bytes(out_dtype);

+  if (is_tensor_scaling) {
+    cudaStream_t stream_0 = nvte_get_compute_stream(0);
+    size_t dpitch = tensor_scaling_sinv_aligment;
+    size_t spitch = lhs_sinv_dtype_bytes;
+    size_t width = lhs_sinv_dtype_bytes;
+    size_t height = lhs_sinv_size;
+    cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height,
+                      cudaMemcpyDeviceToDevice, stream_0);
+    spitch = rhs_sinv_dtype_bytes;
+    width = rhs_sinv_dtype_bytes;
+    height = rhs_sinv_size;
+    cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height,
+                      cudaMemcpyDeviceToDevice, stream_0);
+    lhs_sinv_ptr = lhs_scatter_aligned_ptr;
+    rhs_sinv_ptr = rhs_scatter_aligned_ptr;
+  }
+
   NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
   NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
              "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
@@ -135,14 +173,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   auto bias_shape = std::vector<size_t>{has_bias ? n : 0};

   const int arch = cuda::sm_arch();
-  // It is weird that TE/Common GEMM only use colwise for MXFP8
-  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
-  const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
-                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
-  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
-  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
-  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
-
   if (arch < 100 && is_fp8_gemm) {
     NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
                "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
@@ -224,8 +254,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
       auto tensor_scaling_sinv_shape = std::vector<size_t>{1};
       // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers
       if (!is_empty_gemm) {
-        lhs_sinv_size_i = 1;
-        rhs_sinv_size_i = 1;
+        lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes;
+        rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes;
       }
       if (rhs_use_colwise)  // MatA to enter cuBLAS
         rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
@@ -324,6 +354,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   }

   if (is_fp8_gemm) {
+    if (is_tensor_scaling) {
+      lhs_sinv_size *= tensor_scaling_sinv_aligment;
+      rhs_sinv_size *= tensor_scaling_sinv_aligment;
+    }
     NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ",
                lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size);
     NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ",
...
@@ -142,9 +142,9 @@ class ScaledTensor1x(ScaledTensor):
             pad_width = tuple(
                 (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
             )
-            # This actually pad scale_inv with nan, should we pad it with 127 directly instead?
+            # padding with the smallest number it can present
            self.scale_inv = jnp.pad(
-                self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
+                self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127
             )

     def tree_flatten(self):
...
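For reference, a minimal standalone sketch of the padding call above (the shapes and float32 dtype here are made up for illustration; in ScaledTensor1x the target shape comes from the expected swizzled scale layout). The pad value 2**-127 is, per the updated comment, the smallest number the scale format can represent, so the filler entries stay finite; the previous constant_values=0 ended up as NaN according to the old comment:

import jax.numpy as jnp

scale_inv = jnp.ones((6, 3), dtype=jnp.float32)  # hypothetical unpadded scales
expected_scale_shape = (128, 4)                  # hypothetical padded (swizzle-ready) shape
pad_width = tuple((0, a - b) for a, b in zip(expected_scale_shape, scale_inv.shape))
padded = jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
assert padded.shape == expected_scale_shape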