/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/gemm.h"

#include <numeric>  // std::accumulate over the host copy of group_sizes

#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "xla/ffi/api/c_api.h"

// Scaling-block length of the MXFP8 1D scaling format: one scale factor per
// 32 contiguous elements along the scaled (K) dimension.
#define MXFP8_BLOCK_SIZE 32

namespace transformer_engine {
namespace jax {

// XLA custom-call implementation of a grouped GEMM: `num_gemms` independent
// (optionally FP8/MXFP8) matrix products dispatched through the TE
// multi-stream cuBLAS path.
//
// Logical shapes (row-major, JAX view):
//  - regular grouped GEMM: lhs packs all groups along M (m = sum(group_sizes)),
//    rhs holds num_gemms [k, n] weight matrices, out is [m, n];
//  - grouped dense wgrad (is_grouped_dense_wgrad=true): the contraction dim K
//    is grouped instead (k = sum(group_sizes)), rhs is a single [k, n] matrix
//    and out holds num_gemms [m, n] results.
//
// `workspace` is one flat byte buffer; it is aligned up to 256 bytes here and
// then split evenly across the TE compute streams.
//
// NOTE: group_sizes is copied device->host with a stream synchronize below,
// which may break CUDA-graph capture of this call.
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
                          Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
                          bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
                          bool is_grouped_dense_wgrad) {
  // Notes on matrix layouts and transpose:
  // Jax uses row-major data_layout, on entering this function, each input matrix pair:
  //   A: row-major [m, k] for N - [k, m] for T
  //   B: row-major [k, n] for N - [n, k] for T
  // on exiting this function, JAX expects:
  //   C: row-major with size [m, n].
  // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
  //   A: column-major with size [k, m] for T - [m, k] for N
  //   B: column-major with size [n, k] for T - [k, n] for N
  //
  // If we call cuBLAS GEMM for A * B, the output will be:
  //   C: column-major with size [m, n] --> row-major with size [n, m].
  // To make the output compatible with JAX, we need to swap A and B in the cuBLAS GEMM call.

  int num_streams = nvte_get_num_compute_streams();

  // Inputs. Raw byte pointers are kept as uint8_t* so per-group offsets can be
  // advanced with element-size arithmetic below.
  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
  NVTE_CHECK(group_sizes.dimensions().size() == 1);
  size_t num_gemms = group_sizes.dimensions()[0];

  // Outputs
  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());

  // Round the workspace base address up to the next 256-byte boundary (clear
  // the low 8 address bits) so each per-stream slice is 256-aligned; the
  // buffer was over-allocated by 255 bytes to allow this.
  auto workspace_ptr =
      reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) &
                                  ~static_cast<uintptr_t>(255));
  auto workspace_total_size = product(workspace->dimensions()) - 255;
  auto workspace_size = workspace_total_size / num_streams;

  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");

  // Sanity-check total buffer sizes against the attribute-provided M/N/K.
  size_t expected_lhs_size = m * k;
  size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
  size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
  size_t actual_lhs_size = product(lhs_data.dimensions());
  size_t actual_rhs_size = product(rhs_data.dimensions());
  size_t actual_out_size = product(output->dimensions());
  NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
             expected_lhs_size, ", got ", actual_lhs_size);
  if (!is_grouped_dense_wgrad) {
    NVTE_CHECK(expected_rhs_size == actual_rhs_size,
               "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
               " = ", expected_rhs_size, ", got ", actual_rhs_size);
    NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
               " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
  } else {
    NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
               " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
    NVTE_CHECK(expected_out_size == actual_out_size,
               "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
               " = ", expected_out_size, ", got ", actual_out_size);
  }

  // The per-group sizes live on device; copy them to host since the TE call
  // below needs host-side shapes. Note: the synchronize may break cudaGraph.
  size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
  std::vector<int32_t> dim_list_host(num_gemms);
  auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes,
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

  // Accumulate into size_t (not int) to avoid overflow for very large M/K.
  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), size_t{0});
  if (!is_grouped_dense_wgrad) {
    NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
               ", got sum(group_sizes)=", sum_group_sizes);
  } else {
    NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
               ", got sum(group_sizes)=", sum_group_sizes);
  }

  auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
  bool grad = false;
  bool accumulate = false;
  bool use_split_accumulator = false;
  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};

  const int arch = cuda::sm_arch();
  // It is weird that TE/Common GEMM only uses colwise for MXFP8
  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
  if (arch < 100 && is_fp8_gemm) {
    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
               "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
  }

  // These lists keep the TensorWrapper objects alive until the TE call returns.
  std::vector<TensorWrapper> lhs_wrapper_list;
  std::vector<TensorWrapper> rhs_wrapper_list;
  std::vector<TensorWrapper> bias_wrapper_list;
  std::vector<TensorWrapper> pre_gelu_wrapper_list;
  std::vector<TensorWrapper> out_wrapper_list;
  std::vector<TensorWrapper> workspace_wrapper_list;

  // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM.
  std::vector<NVTETensor> lhs_list;
  std::vector<NVTETensor> rhs_list;
  std::vector<NVTETensor> bias_list;
  std::vector<NVTETensor> pre_gelu_list;
  std::vector<NVTETensor> out_list;
  std::vector<NVTETensor> workspace_list;

  for (size_t i = 0; i < num_gemms; i++) {
    // Matrix data shapes. In the regular path group i owns m_i rows of lhs/out;
    // in the wgrad path group i owns k_i rows/cols of the contraction dim.
    size_t m_i = dim_list_host[i];
    auto lhs_shape = std::vector<size_t>{m_i, k};
    auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
    auto out_shape = std::vector<size_t>{m_i, n};
    if (is_grouped_dense_wgrad) {
      size_t k_i = dim_list_host[i];
      lhs_shape[0] = lhs_is_trans ? k_i : m;
      lhs_shape[1] = lhs_is_trans ? m : k_i;
      rhs_shape[0] = rhs_is_trans ? n : k_i;
      rhs_shape[1] = rhs_is_trans ? k_i : n;
      out_shape[0] = m;
      out_shape[1] = n;
    }

    // Set matrix data pointers
    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
    void *lhs_vptr = static_cast<void *>(lhs_ptr);
    void *rhs_vptr = static_cast<void *>(rhs_ptr);
    if (rhs_use_colwise)  // MatA to enter cuBLAS
      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
    else
      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
    if (lhs_use_colwise)  // MatB to enter cuBLAS
      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
    else
      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);

    // Scale_inv shapes: a single scalar per tensor, except MXFP8 which carries
    // one scale per MXFP8_BLOCK_SIZE elements along K.
    auto lhs_sinv_size = std::vector<size_t>{1};
    auto rhs_sinv_size = std::vector<size_t>{1};
    if (is_mxfp8_scaling) {
      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires the K dim to be divisible by ",
                 MXFP8_BLOCK_SIZE, " (got ", k, ")");
      size_t scale_k = k / MXFP8_BLOCK_SIZE;
      lhs_sinv_size[0] = m_i * scale_k;
      rhs_sinv_size[0] = n * scale_k;
      // Need to add swizzle here
    }

    // Set scale_inv pointers
    void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
    void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
    if (is_fp8_gemm) {
      if (rhs_use_colwise)  // MatA to enter cuBLAS
        rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
      else
        rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
      if (lhs_use_colwise)  // MatB to enter cuBLAS
        lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
      else
        lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
    } else {
      NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
                 "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
    }

    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
    auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);

    // Advance the flat buffer pointers past this group's data for the next pair.
    lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
    rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
    out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
    if (is_fp8_gemm) {
      lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
      rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
    }
    if (has_bias) bias_ptr += n * bias_dtype_bytes;

    // Move objects to the lists to keep them alive
    lhs_wrapper_list.push_back(std::move(lhs_i));
    rhs_wrapper_list.push_back(std::move(rhs_i));
    out_wrapper_list.push_back(std::move(out_i));
    bias_wrapper_list.push_back(std::move(bias_i));
    pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));

    lhs_list.push_back(lhs_wrapper_list.back().data());
    rhs_list.push_back(rhs_wrapper_list.back().data());
    bias_list.push_back(bias_wrapper_list.back().data());
    pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data());
    out_list.push_back(out_wrapper_list.back().data());
  }

  // One equal workspace slice per compute stream.
  auto workspace_shape = std::vector<size_t>{workspace_size};
  for (int i = 0; i < num_streams; i++) {
    auto workspace_i =
        TensorWrapper(static_cast<void *>(workspace_ptr), workspace_shape, DType::kByte);
    workspace_wrapper_list.push_back(std::move(workspace_i));
    workspace_list.push_back(workspace_wrapper_list.back().data());
    workspace_ptr += workspace_size;
  }

  // rhs/lhs are swapped (see the layout note at the top) so the column-major
  // cuBLAS result lands in row-major [m, n] as JAX expects.
  nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(),
                                bias_list.data(), pre_gelu_list.data(), num_gemms, rhs_is_trans,
                                lhs_is_trans, grad, workspace_list.data(), accumulate,
                                use_split_accumulator, num_math_sm, stream);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()      // stream
                                  .Arg<Buffer_Type>()          // lhs_data
                                  .Arg<Buffer_Type>()          // lhs_sinv
                                  .Arg<Buffer_Type>()          // rhs_data
                                  .Arg<Buffer_Type>()          // rhs_sinv
                                  .Arg<Buffer_Type>()          // bias
                                  .Arg<Buffer_Type>()          // group_sizes
                                  .Arg<Buffer_Type>()          // group_offset
                                  .Ret<Buffer_Type>()          // output
                                  .Ret<Buffer_Type>()          // workspace
                                  .Attr<int64_t>("M")
                                  .Attr<int64_t>("N")
                                  .Attr<int64_t>("K")
                                  .Attr<bool>("lhs_is_trans")
                                  .Attr<bool>("rhs_is_trans")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<bool>("has_bias")
                                  .Attr<bool>("is_grouped_dense_wgrad"),
                              FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine