/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/gemm.h"

#include <numeric>

#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"

#define MXFP8_BLOCK_SIZE 32

namespace transformer_engine {
namespace jax {

static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
  // Move the pointer to the next 256B-aligned address (no-op if already aligned)
  return reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(ptr) + 255) &
                                     ~static_cast<uintptr_t>(255));
}
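
// Illustrative arithmetic for the rounding above (no additional logic): a pointer at address
// 0x1001 becomes (0x1001 + 255) & ~255 = 0x1100, while an already-aligned address such as
// 0x1200 is unchanged, since (0x1200 + 255) & ~255 = 0x1200.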

Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
                          Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
                          bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
                          bool is_grouped_dense_wgrad) {
  // Notes on matrix layouts and transpose:
  // JAX uses row-major data layout. On entering this function, each input matrix pair is:
  //   A: row-major [m, k] for N - [k, m] for T
  //   B: row-major [k, n] for N - [n, k] for T
  // On exiting this function, JAX expects:
  //   C: row-major with size [m, n].
  // cuBLAS uses column-major data layout. In this view, each input matrix pair is:
  //   A: column-major with size [k, m] for T - [m, k] for N
  //   B: column-major with size [n, k] for T - [k, n] for N
  //
  // If we call cuBLAS GEMM for A * B, the output will be:
  //   C: column-major with size [m, n] --> row-major with size [n, m].
  // To make the output compatible with JAX, we need to swap A and B in the cuBLAS GEMM call.

  int num_streams = nvte_get_num_compute_streams();

  // Inputs
  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
  NVTE_CHECK(group_sizes.dimensions().size() == 1);
  size_t num_gemms = group_sizes.dimensions()[0];

  // Outputs
  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());

  // Round the workspace base up to the next 256B boundary so every buffer carved out of it
  // stays 256B-aligned
  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
  workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
  auto workspace_total_size = product(workspace->dimensions());
  auto lhs_sinv_size = product(lhs_sinv.dimensions());
  auto rhs_sinv_size = product(rhs_sinv.dimensions());
  auto workspace_size =
      (workspace_total_size - lhs_sinv_size - rhs_sinv_size - 3 * 256) / num_streams;
  auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
  swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
  auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
  swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
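
  // Layout note (descriptive only, mirrors the arithmetic above): the workspace buffer holds
  // num_streams equal per-stream cuBLAS workspaces of workspace_size each, followed by a region
  // for the swizzled lhs scale_inv and a region for the swizzled rhs scale_inv. The "3 * 256"
  // slack leaves room for the up-to-255-byte padding introduced by each of the three alignment
  // adjustments.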

  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");

  size_t expected_lhs_size = m * k;
  size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
  size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
  size_t actual_lhs_size = product(lhs_data.dimensions());
  size_t actual_rhs_size = product(rhs_data.dimensions());
  size_t actual_out_size = product(output->dimensions());
  NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
             expected_lhs_size, ", got ", actual_lhs_size);
  if (!is_grouped_dense_wgrad) {
    NVTE_CHECK(expected_rhs_size == actual_rhs_size,
               "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
               " = ", expected_rhs_size, ", got ", actual_rhs_size);
    NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
               " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
  } else {
    NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
               " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
    NVTE_CHECK(expected_out_size == actual_out_size,
               "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
               " = ", expected_out_size, ", got ", actual_out_size);
  }

  size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
  std::vector<int32_t> dim_list_host(num_gemms);
  auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
  cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
                  stream);
  // Note: this D2H copy plus sync may break cudaGraph capture.
  cudaStreamSynchronize(stream);
  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
  if (!is_grouped_dense_wgrad) {
    NVTE_CHECK(m == sum_group_sizes,
               "Unexpected group_sizes! M = ", m, ", got sum(group_sizes)=", sum_group_sizes);
  } else {
    NVTE_CHECK(k == sum_group_sizes,
               "Unexpected group_sizes! K = ", k, ", got sum(group_sizes)=", sum_group_sizes);
  }

  auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
  bool grad = false;
  bool accumulate = false;
  bool use_split_accumulator = false;
  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};

  const int arch = cuda::sm_arch();
  // Note: the TE/common GEMM path only uses column-wise inputs for MXFP8
  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
  const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
  if (arch < 100 && is_fp8_gemm) {
    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
               "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
  }
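
  // Illustrative example of the flags above (no new logic): for an MXFP8 GEMM with
  // lhs_is_trans = false and rhs_is_trans = true, both lhs_use_colwise and rhs_use_colwise are
  // false, so both operands are consumed through their row-wise data/scale_inv; with
  // rhs_is_trans = false, the rhs is instead consumed through its column-wise data/scale_inv.
  // For tensor scaling and non-FP8 inputs, both flags are always false.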

  // These lists keep the TensorWrapper objects alive
  std::vector<TensorWrapper> lhs_wrapper_list;
  std::vector<TensorWrapper> rhs_wrapper_list;
  std::vector<TensorWrapper> lhs_swizzle_wrapper_list;  // For MXFP8 scale_inv swizzling
  std::vector<TensorWrapper> rhs_swizzle_wrapper_list;
  std::vector<TensorWrapper> bias_wrapper_list;
  std::vector<TensorWrapper> pre_gelu_wrapper_list;
  std::vector<TensorWrapper> out_wrapper_list;
  std::vector<TensorWrapper> workspace_wrapper_list;

  // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
  std::vector<NVTETensor> lhs_list;
  std::vector<NVTETensor> rhs_list;
  std::vector<NVTETensor> lhs_swizzle_list;
  std::vector<NVTETensor> rhs_swizzle_list;
  std::vector<NVTETensor> bias_list;
  std::vector<NVTETensor> pre_gelu_list;
  std::vector<NVTETensor> out_list;
  std::vector<NVTETensor> workspace_list;

  size_t lhs_sinv_total_size = 0;
  size_t rhs_sinv_total_size = 0;

  std::vector<void *> zero_out_dptr_list;
  std::vector<size_t> zero_out_size_list;

  for (size_t i = 0; i < num_gemms; i++) {
    // Matrix data shapes
    size_t m_i = dim_list_host[i];
    auto lhs_shape_i = std::vector<size_t>{m_i, k};
    auto rhs_shape_i = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
    auto out_shape_i = std::vector<size_t>{m_i, n};
    if (is_grouped_dense_wgrad) {
      size_t k_i = dim_list_host[i];
      lhs_shape_i[0] = lhs_is_trans ? k_i : m;
      lhs_shape_i[1] = lhs_is_trans ? m : k_i;
      rhs_shape_i[0] = rhs_is_trans ? n : k_i;
      rhs_shape_i[1] = rhs_is_trans ? k_i : n;
      out_shape_i[0] = m;
      out_shape_i[1] = n;
    }
    size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1];
    size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1];
    size_t out_size = out_shape_i[0] * out_shape_i[1];
    bool is_empty_gemm = lhs_size == 0 || rhs_size == 0;
    if (is_empty_gemm && out_size > 0) {
      zero_out_dptr_list.push_back(out_ptr);
      zero_out_size_list.push_back(out_size * out_dtype_bytes);
    }

    // Set matrix data pointers
    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape_i, out_dtype);
    void *lhs_vptr = static_cast<void *>(lhs_ptr);
    void *rhs_vptr = static_cast<void *>(rhs_ptr);
    if (rhs_use_colwise)  // MatA to enter cuBLAS
      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
    else
      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
    if (lhs_use_colwise)  // MatB to enter cuBLAS
      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
    else
      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);

    // Set scale_inv shapes and pointers
    void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
    void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
    size_t lhs_sinv_size_i = 0;
    size_t rhs_sinv_size_i = 0;
    if (is_tensor_scaling) {
      auto tensor_scaling_sinv_shape = std::vector<size_t>{1};
      // If is_empty_gemm, scale_inv has no corresponding value, so do not advance the pointers
      if (!is_empty_gemm) {
        lhs_sinv_size_i = 1;
        rhs_sinv_size_i = 1;
      }
      if (rhs_use_colwise)  // MatA to enter cuBLAS
        rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
      else
        rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
      if (lhs_use_colwise)  // MatB to enter cuBLAS
        lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
      else
        lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
    } else if (is_mxfp8_scaling) {
      auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
      auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
      void *swizzled_lhs_sinv_vptr = static_cast<void *>(swizzled_lhs_sinv_ptr);
      void *swizzled_rhs_sinv_vptr = static_cast<void *>(swizzled_rhs_sinv_ptr);
      // {lhs, rhs}_swizzle_i point to the unswizzled scale_inv data as input, while {lhs, rhs}_i
      // point to the swizzled scale_inv data (stored in the workspace, only used for GEMM).
      // Note: even if is_empty_gemm is true, the scale_inv buffers are still non-empty, so the
      // pointers still need to be advanced.
      auto lhs_sinv_shape_i =
          get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
      auto rhs_sinv_shape_i =
          get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
      lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
      rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
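      // Scale granularity reminder (descriptive only): with MXFP8 1D scaling each scale_inv
      // value covers a 1 x MXFP8_BLOCK_SIZE (= 32) block of elements along the scaled dimension,
      // so an operand carries roughly rows * cols / 32 scales; the exact (padded) layout is
      // whatever get_mxfp8_scale_shape() returns above.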
      if (lhs_use_colwise) {
        lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
        lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
        lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
      } else {
        lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
        lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
        lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
      }
      if (rhs_use_colwise) {
        rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
        rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
        rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
      } else {
        rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
        rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
        rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
      }
      if (!is_empty_gemm) {
        lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i));
        rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i));
        lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data());
        rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data());
      }
    } else {
      NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
                 "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
    }

    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
    auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);

    // Update pointers for the next GEMM pair
    lhs_ptr += lhs_size * lhs_dtype_bytes;
    rhs_ptr += rhs_size * rhs_dtype_bytes;
    out_ptr += out_size * out_dtype_bytes;
    if (is_fp8_gemm) {
      lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
      rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
      lhs_sinv_total_size += lhs_sinv_size_i;
      rhs_sinv_total_size += rhs_sinv_size_i;
      if (is_mxfp8_scaling) {
        swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
        swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
      }
    }
    if (has_bias) bias_ptr += n * bias_dtype_bytes;

    // Move objects into the lists to keep them alive
    if (is_empty_gemm) continue;
    lhs_wrapper_list.push_back(std::move(lhs_i));
    rhs_wrapper_list.push_back(std::move(rhs_i));
    out_wrapper_list.push_back(std::move(out_i));
    bias_wrapper_list.push_back(std::move(bias_i));
    pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));

    lhs_list.push_back(lhs_wrapper_list.back().data());
    rhs_list.push_back(rhs_wrapper_list.back().data());
    bias_list.push_back(bias_wrapper_list.back().data());
    pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data());
    out_list.push_back(out_wrapper_list.back().data());
  }

  auto workspace_shape = std::vector<size_t>{workspace_size};
  for (int i = 0; i < num_streams; i++) {
    auto workspace_i =
        TensorWrapper(static_cast<void *>(workspace_ptr), workspace_shape, DType::kByte);
    workspace_wrapper_list.push_back(std::move(workspace_i));
    workspace_list.push_back(workspace_wrapper_list.back().data());
    workspace_ptr += workspace_size;
  }

  if (is_fp8_gemm) {
    NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ",
               lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size);
    NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ",
               rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size);
  }

  size_t num_non_empty_gemms = lhs_list.size();

  if (is_mxfp8_scaling) {
    for (int i = 0; i < num_non_empty_gemms; i++) {
      // The i-th GEMM runs on the (i % num_streams)-th compute stream; swizzle its scaling
      // factors on the same stream so the swizzling finishes before that GEMM starts.
      int stream_id = i % num_streams;
      cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
      nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i);
      nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i);
    }
  }

  // Launch zero-out kernels before the GEMM calls so they reuse the sync done by the
  // multi-stream GEMM
  size_t num_zero_outs = zero_out_dptr_list.size();
  for (int i = 0; i < num_zero_outs; i++) {
    int stream_id = i % num_streams;
    cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
    void *dptr = zero_out_dptr_list[i];
    size_t count = zero_out_size_list[i];
    NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i));
  }
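
  // As noted at the top of this function, rhs and lhs are passed to cuBLAS in swapped order
  // (rhs as matrix A, lhs as matrix B) so that the column-major cuBLAS result lands in memory
  // as the row-major [m, n] output JAX expects.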
  nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(),
                                bias_list.data(), pre_gelu_list.data(), num_non_empty_gemms,
                                rhs_is_trans, lhs_is_trans, grad, workspace_list.data(),
                                accumulate, use_split_accumulator, num_math_sm, stream);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // lhs_data
                                  .Arg<Buffer_Type>()      // lhs_sinv
                                  .Arg<Buffer_Type>()      // rhs_data
                                  .Arg<Buffer_Type>()      // rhs_sinv
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Arg<Buffer_Type>()      // group_offset
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("M")
                                  .Attr<int64_t>("N")
                                  .Attr<int64_t>("K")
                                  .Attr<bool>("lhs_is_trans")
                                  .Attr<bool>("rhs_is_trans")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<bool>("has_bias")
                                  .Attr<bool>("is_grouped_dense_wgrad"),
                              FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine