Unverified commit ef28c865, authored by Phuong Nguyen, committed by GitHub

[JAX] NVFP4 scale swizzling via nvte kernel (#2350)



* swizzle via nvte

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent d0d40631
@@ -533,6 +533,9 @@ class GemmPrimitive(BasePrimitive):
         # Declare cuBLAS workspace
         workspace_size = get_cublas_workspace_size_bytes()
+        # NVFP4 swizzling happens via an nvte kernel instead of JAX transposes
+        if scaling_mode.is_nvfp4_scaling:
+            workspace_size += lhs_scale_inv.size + rhs_scale_inv.size
         if not collective_op.is_none:
             workspace_size *= get_cgemm_num_max_streams()
         # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
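For a sense of the sizing: the extra bytes are exactly the element counts of the two scale_inv tensors, since NVFP4 scale factors are one byte each. Below is a back-of-the-envelope sketch, assuming unpadded scales with one byte per 16-element block; `nvfp4_gemm_workspace_bytes` is a hypothetical helper for illustration only, since the Python change above uses the actual (padded) `lhs_scale_inv.size` and `rhs_scale_inv.size` rather than recomputing them.

#include <cstddef>

// Hypothetical helper, illustration only. Real scale_inv tensors are padded
// to tile boundaries, so the actual sizes come from the buffers themselves.
size_t nvfp4_gemm_workspace_bytes(size_t cublas_bytes, size_t m, size_t n, size_t k) {
  const size_t block = 16;                         // NVFP4 1D scaling block size
  const size_t lhs_scale_bytes = m * (k / block);  // ~ lhs_scale_inv.size
  const size_t rhs_scale_bytes = n * (k / block);  // ~ rhs_scale_inv.size
  return cublas_bytes + lhs_scale_bytes + rhs_scale_bytes;
}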
@@ -662,6 +665,8 @@ class GemmPrimitive(BasePrimitive):
         rhs_scale_inv = apply_padding_to_scale_inv(
             rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
         )
-        lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
-        rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
+        # Only perform the JAX-based swizzle for MXFP8; the NVFP4 swizzle goes through an nvte kernel
+        if scaling_mode.is_mxfp8_scaling:
+            lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
+            rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
...
@@ -34,8 +34,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
 }

 std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
-    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode,
-    size_t axis_boundary, bool rowwise) {
+    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, uint8_t *swizzle_scale_ptr,
+    JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
   // Set tensor data with collapsed 2D shape
   auto buffer_dims = buffer.dimensions();
   std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
@@ -56,18 +56,33 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
     NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");

     std::vector<size_t> scale_shape = {1};
-    if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
+    auto is_nvfp4 = is_nvfp4_scaling(scaling_mode);
+    auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
+    if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) {
       // Block scaling also needs to be collapsed to match 2D data
       scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
                      product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
+      NVTE_CHECK(typeToSize(scale_dtype) == 1,
+                 "Inverse scale factors need to have an 8-bit data type.");
     }

-    auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
-    if (rowwise) {
-      input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
-    } else {
-      input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
-    }
+    if (!is_nvfp4) {
+      if (rowwise) {
+        input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
+      } else {
+        input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
+      }
+    } else {  // Swizzle for NVFP4
+      NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS");
+      input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
+
+      // Create tensor to hold swizzled scale factor
+      TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
+      output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
+      output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
+
+      // Launch swizzle kernel
+      nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
+
+      // Set swizzled scales into the input tensor
+      input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
+    }
   }

   return std::make_tuple(std::move(input), input_shape);
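For readers unfamiliar with the nvte swizzle API, here is a minimal sketch of the pattern the new branch uses, pulled out of the FFI plumbing. It is not the PR's exact code: it assumes pre-allocated device pointers `data`, `scales`, and `swizzled_scales`, uses `NVTE_NVFP4_1D_SCALING` as a stand-in for whatever `get_nvte_scaling_mode(scaling_mode)` returns for NVFP4, and the dtypes are assumptions based on the 8-bit scale check above.

#include <vector>

#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>

using transformer_engine::DType;
using transformer_engine::TensorWrapper;

void swizzle_nvfp4_scales(void *data, void *scales, void *swizzled_scales,
                          const std::vector<size_t> &data_shape,
                          const std::vector<size_t> &scale_shape, cudaStream_t stream) {
  // Input: quantized NVFP4 data plus its linear (unswizzled) scale factors.
  // NVTE_NVFP4_1D_SCALING and the dtypes below are assumed names.
  TensorWrapper input(NVTE_NVFP4_1D_SCALING);
  input.set_rowwise_data(data, DType::kFloat4E2M1, data_shape);
  input.set_rowwise_scale_inv(scales, DType::kFloat8E4M3, scale_shape);

  // Output aliases the same data but points its scales at scratch memory.
  TensorWrapper output(NVTE_NVFP4_1D_SCALING);
  output.set_rowwise_data(data, DType::kFloat4E2M1, data_shape);
  output.set_rowwise_scale_inv(swizzled_scales, DType::kFloat8E4M3, scale_shape);

  // Rearranges the scales into the interleaved layout block-scaled GEMM expects.
  nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
}

After this call, the GEMM operand is pointed at `swizzled_scales`, which is exactly what the last `set_rowwise_scale_inv` in the diff does.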
@@ -145,16 +160,34 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
                    int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
                    bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
                    bool use_split_accumulator, JAXX_Collective_Op collective_op) {
+  // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales)
+  uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr;
+  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
+  workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
+  size_t workspace_size = static_cast<size_t>(workspace->element_count()) - 256;
+  if (is_nvfp4_scaling(scaling_mode)) {
+    auto lhs_scale_size = product(lhs_scale_inv.dimensions());
+    auto rhs_scale_size = product(rhs_scale_inv.dimensions());
+    workspace_size = workspace_size - lhs_scale_size - rhs_scale_size;
+    lhs_swizzle_scale_ptr = workspace_ptr;
+    rhs_swizzle_scale_ptr = workspace_ptr + lhs_scale_size;
+    workspace_ptr = rhs_swizzle_scale_ptr + rhs_scale_size;
+  }
+  auto workspace_ = TensorWrapper(workspace_ptr, std::vector<size_t>{workspace_size}, DType::kByte);
+
   // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
   // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
   bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
                          (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
   bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
   bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
-  auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
-                                                           lhs_axis_boundary, make_lhs_rowwise);
-  auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
-                                                           rhs_axis_boundary, make_rhs_rowwise);
+  auto [lhs_, lhs_shape] =
+      xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr,
+                                      scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
+  auto [rhs_, rhs_shape] =
+      xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr,
+                                      scaling_mode, rhs_axis_boundary, make_rhs_rowwise);

   std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
                                    (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
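The carving above is the heart of the change: the swizzle scratch lives inside the (enlarged) cuBLAS workspace buffer, so no separate allocation is needed and XLA's buffer bookkeeping stays unchanged. A minimal, self-contained sketch of the layout, with no TE dependencies and illustrative names:

#include <cstddef>
#include <cstdint>

static uint8_t *align_up_256(uint8_t *ptr) {
  auto p = reinterpret_cast<uintptr_t>(ptr);
  return reinterpret_cast<uint8_t *>((p + 255) & ~uintptr_t{255});
}

struct WorkspaceLayout {
  uint8_t *lhs_swizzle_scratch;  // holds swizzled LHS scales
  uint8_t *rhs_swizzle_scratch;  // holds swizzled RHS scales
  uint8_t *cublas_workspace;     // what remains for cuBLAS itself
  size_t cublas_workspace_size;
};

// `total_bytes` is the buffer JAX allocated, which the Python side already
// grew by lhs_scale_bytes + rhs_scale_bytes for NVFP4.
WorkspaceLayout carve(uint8_t *base, size_t total_bytes, size_t lhs_scale_bytes,
                      size_t rhs_scale_bytes) {
  uint8_t *ptr = align_up_256(base);
  size_t remaining = total_bytes - 256;  // headroom consumed by alignment
  WorkspaceLayout out{};
  out.lhs_swizzle_scratch = ptr;
  out.rhs_swizzle_scratch = ptr + lhs_scale_bytes;
  out.cublas_workspace = out.rhs_swizzle_scratch + rhs_scale_bytes;
  out.cublas_workspace_size = remaining - lhs_scale_bytes - rhs_scale_bytes;
  return out;
}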
@@ -191,11 +224,6 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
   }
   auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);

-  // cuBLAS workspace + 256 alignment enforcement
-  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
-  workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
-  std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
-  auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
-
   auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
   float one = 1.;
...