Unverified commit ac76d55c, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Fixes for the grouped_gemm with MXFP8 (#1945)



* memset for the mxfp8 scale padding
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 11fecc41
...@@ -261,6 +261,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -261,6 +261,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
size_t input_dtype_bytes = te_dtype_bytes(in_dtype); size_t input_dtype_bytes = te_dtype_bytes(in_dtype);
size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
...@@ -314,6 +315,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -314,6 +315,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t colwise_sinv_size = 0; size_t colwise_sinv_size = 0;
size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1; size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1;
size_t num_non_empty_groups = 0; size_t num_non_empty_groups = 0;
size_t total_rowwise_sinv_size = 0;
size_t total_colwise_sinv_size = 0;
for (size_t i = 0; i < num_groups; i++) { for (size_t i = 0; i < num_groups; i++) {
size_t m_i = dim_list_host[i] * non_group_m; size_t m_i = dim_list_host[i] * non_group_m;
// Skip for zero-size input + shift the scale ptr // Skip for zero-size input + shift the scale ptr
...@@ -379,6 +382,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -379,6 +382,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
sinv_ptr += sinv_size * sinv_dtype_bytes; sinv_ptr += sinv_size * sinv_dtype_bytes;
colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes; colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes;
amax_ptr += amax_dtype_bytes; amax_ptr += amax_dtype_bytes;
total_rowwise_sinv_size += sinv_size;
total_colwise_sinv_size += colwise_sinv_size;
}
if (is_mxfp8_scaling) {
nvte_memset(scale_invs->untyped_data(), 0, total_rowwise_sinv_size, stream);
nvte_memset(colwise_scale_invs->untyped_data(), 0, total_colwise_sinv_size, stream);
} }
QuantizationConfigWrapper quant_config; QuantizationConfigWrapper quant_config;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment