/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/gemm.h"

#include <vector>

#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "transformer_engine/multi_stream.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
                          Variadic_Result_Type output_list, int64_t num_gemms,
                          JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
  // Notes on matrix layouts and transpose:
  // JAX uses a row-major data layout. On entering this function, each input matrix pair is:
  //   A: row-major with size [m, k],
  //   B: row-major with size [n, k], needs transpose.
  // On exiting this function, JAX expects:
  //   C: row-major with size [m, n].
  // cuBLAS uses a column-major data layout. In this view, each input matrix pair is:
  //   A: column-major with size [k, m], needs transpose,
  //   B: column-major with size [k, n].
  // If we call cuBLAS GEMM for A * B, the output will be:
  //   C: column-major with size [m, n] --> row-major with size [n, m].
  // To make the output compatible with JAX, we need to swap A and B in the cuBLAS GEMM call.

  if (num_gemms <= 0) {
    return ffi_with_cuda_error_check();
  }

  size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
  size_t expected_output_size = num_gemms + 1;
  size_t actual_input_size = input_list.size();
  size_t actual_output_size = output_list.size();
  NVTE_CHECK(actual_input_size == expected_input_size, "Expected ", expected_input_size,
             " input tensors, got ", actual_input_size);
  NVTE_CHECK(actual_output_size == expected_output_size, "Expected ", expected_output_size,
             " output tensors, got ", actual_output_size);

  bool trans_lhs = true;
  bool trans_rhs = false;
  auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
  bool grad = false;
  bool accumulate = false;
  bool use_split_accumulator = false;

  // These lists are to keep the TensorWrapper objects alive
  std::vector<TensorWrapper> lhs_wrapper_list;
  std::vector<TensorWrapper> rhs_wrapper_list;
  std::vector<TensorWrapper> bias_wrapper_list;
  std::vector<TensorWrapper> pre_gelu_wrapper_list;
  std::vector<TensorWrapper> out_wrapper_list;
  std::vector<TensorWrapper> workspace_wrapper_list;

  // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
  std::vector<NVTETensor> lhs_list;
  std::vector<NVTETensor> rhs_list;
  std::vector<NVTETensor> bias_list;
  std::vector<NVTETensor> pre_gelu_list;
  std::vector<NVTETensor> out_list;
  std::vector<NVTETensor> workspace_list;

  int lhs_list_offset = 0;
  int rhs_list_offset = num_gemms;
  int lhs_sinv_list_offset = 2 * num_gemms;
  int rhs_sinv_list_offset = 3 * num_gemms;
  int bias_list_offset = 4 * num_gemms;
  int out_list_offset = 0;
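  // For example, with num_gemms = 2 and has_bias = true, the offsets above imply that
  // input_list is ordered as
  //   [lhs_0, lhs_1, rhs_0, rhs_1, lhs_sinv_0, lhs_sinv_1, rhs_sinv_0, rhs_sinv_1, bias_0, bias_1]
  // and output_list as [out_0, out_1, workspace].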
  for (int i = 0; i < num_gemms; i++) {
    Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
    Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
    Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
    Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
    Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();

    DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
    DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
    DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());

    void *lhs_ptr = lhs_i.untyped_data();
    void *rhs_ptr = rhs_i.untyped_data();
    void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
    void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
    void *out_ptr = out_i->untyped_data();

    // Placeholder for bias since it can be empty
    DType bias_dtype = DType::kFloat32;
    void *bias_ptr = nullptr;

    auto lhs_shape_ = lhs_i.dimensions();
    auto rhs_shape_ = rhs_i.dimensions();

    // lhs and rhs have shape [1, m, k] and [1, n, k]
    size_t m = lhs_shape_[1];
    size_t n = rhs_shape_[1];
    size_t k = lhs_shape_[2];

    auto lhs_shape = std::vector<size_t>{m, k};
    auto rhs_shape = std::vector<size_t>{n, k};
    auto out_shape = std::vector<size_t>{n, m};
    auto lhs_sinv_shape = std::vector<size_t>{1, 1};
    auto rhs_sinv_shape = std::vector<size_t>{1, 1};

    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
        scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
      float *amax_dptr = nullptr;
      float *scale_dptr = nullptr;
      auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
                                  reinterpret_cast<float *>(lhs_sinv_ptr));
      auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
                                  reinterpret_cast<float *>(rhs_sinv_ptr));
      lhs_wrapper_list.push_back(std::move(lhs_i_));
      rhs_wrapper_list.push_back(std::move(rhs_i_));
    } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
      // Note: the scale_inv array should have been swizzled in Python before lowering
      auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
      auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
      for (int j = 0; j < 2; j++) {
        lhs_sinv_shape[j] = lhs_sinv_shape_[j];
        rhs_sinv_shape[j] = rhs_sinv_shape_[j];
      }
      NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
      TensorWrapper lhs_i_(nvte_scaling_mode);
      TensorWrapper rhs_i_(nvte_scaling_mode);
      lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
      rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
      lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
      rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
      lhs_wrapper_list.push_back(std::move(lhs_i_));
      rhs_wrapper_list.push_back(std::move(rhs_i_));
    } else {
      NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
    }

    auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);

    void *pre_gelu_ptr = nullptr;
    auto bias_shape = std::vector<size_t>{0};
    auto pre_gelu_shape = std::vector<size_t>{0};
    if (has_bias) {
      auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
      Buffer_Type bias_i = bias_i_get.value();
      bias_ptr = bias_i.untyped_data();
      bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
      bias_shape[0] = n;
    }
    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
    auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);

    out_wrapper_list.push_back(std::move(out_i_));
    bias_wrapper_list.push_back(std::move(bias_i));
    pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));

    lhs_list.push_back(lhs_wrapper_list.back().data());
    rhs_list.push_back(rhs_wrapper_list.back().data());
    bias_list.push_back(bias_wrapper_list.back().data());
    pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data());
    out_list.push_back(out_wrapper_list.back().data());
  }

  auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
  Result_Type workspace = workspace_get.value();
  uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
  auto num_streams = nvte_get_num_compute_streams();
  size_t workspace_size = workspace->dimensions()[0] / num_streams;
  auto workspace_shape = std::vector<size_t>{workspace_size};
  // Split the single workspace buffer evenly across the compute streams
  for (int i = 0; i < num_streams; i++) {
    auto workspace_i =
        TensorWrapper(static_cast<void *>(workspace_ptr), workspace_shape, DType::kByte);
    workspace_wrapper_list.push_back(std::move(workspace_i));
    workspace_list.push_back(workspace_wrapper_list.back().data());
    workspace_ptr += workspace_size;
  }
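  // Per the layout notes at the top of this function, lhs and rhs are swapped in the cuBLAS
  // call below (rhs is passed as the first operand) so that the column-major result cuBLAS
  // produces reads back in JAX as the expected row-major output.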
  nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(),
                                bias_list.data(), pre_gelu_list.data(), num_gemms, trans_lhs,
                                trans_rhs, grad, workspace_list.data(), accumulate,
                                use_split_accumulator, num_math_sm, stream);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .RemainingArgs()         // input list
                                  .RemainingRets()         // output list
                                  .Attr<int64_t>("num_gemms")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("has_bias"),
                              FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine