Commit 519aae87 authored by Paweł Gadziński, committed by wenjh

[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add FP8 scale support and fix alignment for grouped GEMM

- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Grouped GEMM: code cleanup and NULL C support

- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Grouped GEMM: per-matrix alpha/beta support

- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix alpha/beta numel - use SimpleTensor::numel()
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Refactor: move grouped GEMM to separate file and cleanup API
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/gemm/config.h
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* changed
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* suggestions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactored hopper tensor selection
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent fd061211
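The call sequence below is a minimal sketch of the new API, distilled from the test added in this commit (tensor construction, workspace sizing, and error handling are elided; grouped_A, grouped_B, and grouped_D stand for caller-built NVTEGroupedTensor handles):

    // Per-matrix alpha/beta: float32 tensors with one element per GEMM in the group.
    // C may be NULL when every beta is 0; D is then used internally as a placeholder.
    nvte_grouped_gemm(grouped_A, /*transa=*/1,
                      grouped_B, /*transb=*/0,
                      /*C=*/nullptr, grouped_D,
                      alpha, beta,            // NVTETensor with num_tensors float32 values each
                      setup_ws, cublas_ws,    // setup: pointer/dim arrays; cuBLAS: >= 32 MiB
                      /*config=*/nullptr,     // or an NVTEGroupedMatmulConfig with avg M/N/K hints
                      stream);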
@@ -29,6 +29,7 @@ list(APPEND test_cuda_sources
test_causal_softmax.cu
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu
../test_common.cu)
if(USE_ROCM)
list(APPEND test_cuda_sources
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum class InputCase {
kFP8Current,
kBF16,
};
enum class ShapeCase {
kAllSame,
kSameFirst,
kSameLast,
kAllDifferent,
};
size_t grouped_setup_workspace_size(const size_t num_tensors) {
const size_t ptr_bytes = num_tensors * sizeof(void*);
const size_t int_bytes = num_tensors * sizeof(int);
// Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
size_t size = 6 * ptr_bytes + 6 * int_bytes;
const size_t alignment = 256;
size = ((size + alignment - 1) / alignment) * alignment;
return size;
}
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
fillUniform(&input_fp32);
Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(fp8.data(), config, 0);
nvte_quantize(input_fp32.data(), fp8.data(), 0);
return fp8;
}
Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor t(name, shape, DType::kBFloat16);
const size_t numel = shape[0] * shape[1];
std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
return t;
}
struct TestParams {
InputCase input_case;
bool transa;
bool transb;
ShapeCase shape_case;
bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0)
};
// Returns a vector of (M, N, K) tuples for each GEMM in the group.
// M - number of rows in output D
// N - number of columns in output D
// K - reduction dimension shared between A and B
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
switch (scase) {
case ShapeCase::kAllSame:
return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
case ShapeCase::kSameFirst:
// Same M (first dim), varying N and K
return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
case ShapeCase::kSameLast:
// Same N (last dim), varying M and K
return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
case ShapeCase::kAllDifferent:
default:
return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
}
}
void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130100
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);
const size_t num_gemms = shapes.size();
std::vector<Tensor> A_tensors;
std::vector<Tensor> B_tensors;
std::vector<Tensor> D_multi;
A_tensors.reserve(num_gemms);
B_tensors.reserve(num_gemms);
D_multi.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
: std::vector<size_t>{K, M};
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
: std::vector<size_t>{N, K};
switch (params.input_case) {
case InputCase::kFP8Current: {
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape));
break;
}
case InputCase::kBF16: {
A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
break;
}
}
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N},
DType::kBFloat16));
}
std::vector<NVTETensor> A_ptrs(num_gemms);
std::vector<NVTETensor> B_ptrs(num_gemms);
std::vector<NVTETensor> D_ptrs(num_gemms);
std::vector<Tensor> workspaces(num_gemms);
std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
std::vector<Tensor*> A_views;
std::vector<Tensor*> B_views;
A_views.reserve(num_gemms);
B_views.reserve(num_gemms);
// Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);
const size_t cublas_ws_bytes = 32ull * 1024 * 1024;
for (size_t i = 0; i < num_gemms; ++i) {
A_ptrs[i] = A_tensors[i].data();
B_ptrs[i] = B_tensors[i].data();
D_ptrs[i] = D_multi[i].data();
workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
workspace_ptrs[i] = workspaces[i].data();
A_views.push_back(&A_tensors[i]);
B_views.push_back(&B_tensors[i]);
}
nvte_multi_tensor_gemm(A_ptrs.data(),
B_ptrs.data(),
D_ptrs.data(),
bias_ptrs.data(),
gelu_ptrs.data(),
static_cast<int>(num_gemms),
params.transa,
params.transb,
false, // grad
workspace_ptrs.data(),
false, // accumulate
false, // use_split_accumulator
0, // sm_count
0);
GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode());
GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode());
std::vector<Tensor> C_tensors;
std::vector<Tensor> D_group_tensors;
C_tensors.reserve(num_gemms);
D_group_tensors.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
(void)K;
if (!params.use_null_c) {
C_tensors.emplace_back(Tensor("C" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
}
D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype())));
}
std::vector<Tensor*> C_views, D_views;
for (size_t i = 0; i < num_gemms; ++i) {
if (!params.use_null_c) {
C_views.push_back(&C_tensors[i]);
}
D_views.push_back(&D_group_tensors[i]);
}
std::optional<GroupedBuffers> grouped_C;
if (!params.use_null_c) {
grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
}
GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);
// Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
Tensor alpha_tensor("alpha", std::vector<size_t>{num_gemms}, DType::kFloat32);
Tensor beta_tensor("beta", std::vector<size_t>{num_gemms}, DType::kFloat32);
std::vector<float> alpha_vals(num_gemms, 1.f);
std::vector<float> beta_vals(num_gemms, 0.f);
NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
nvte_grouped_gemm(grouped_A.get_handle(),
params.transa,
grouped_B.get_handle(),
params.transb,
params.use_null_c ? nullptr : grouped_C->get_handle(),
grouped_D.get_handle(),
alpha_tensor.data(),
beta_tensor.data(),
setup_ws.data(),
cublas_ws.data(),
nullptr, // config (use defaults)
0);
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
static_cast<size_t>(std::get<1>(shapes[i]))},
D_multi[i].dtype());
const size_t offset_bytes = static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
static_cast<char*>(grouped_D.get_data()) + offset_bytes,
grouped_D.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
grouped_split.to_cpu();
D_multi[i].to_cpu();
auto [atol, rtol] = getTolerances(D_multi[i].dtype());
compareResults("grouped_vs_multi",
grouped_split,
D_multi[i].rowwise_cpu_dptr<bf16>(),
true,
atol,
rtol);
}
#endif // CUBLAS_VERSION >= 130100
}
class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
run_grouped_gemm_case(GetParam());
}
std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
"tb" + (info.param.transb ? "T" : "N");
const std::string null_c = info.param.use_null_c ? "_NullC" : "";
return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
}
// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// Basic tests
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
};
INSTANTIATE_TEST_SUITE_P(OperatorTest,
GroupedGemmTest,
::testing::ValuesIn(kTestParams),
MakeGroupedGemmTestName);
} // namespace
@@ -9,6 +9,7 @@
#include <algorithm>
#include <memory>
#include <numeric>
#include <random>
#include <iostream>
#include <cassert>
@@ -1116,4 +1117,166 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode) {
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
const NVTEShape shape = tensors[0]->rowwise_shape();
const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
GroupedBuffers grouped;
grouped.elem_size = elem_size;
grouped.num_tensors = num_tensors;
grouped.dtype = dtype;
grouped.scaling_mode = scaling_mode;
grouped.tensor_bytes.resize(num_tensors);
grouped.offsets_host.resize(num_tensors, 0);
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = tensors[i]->rowwise_shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
grouped.tensor_bytes[i] = bytes(s, dtype);
}
const bool same_first = std::all_of(first_dims.begin(), first_dims.end(),
[&](int64_t v) { return v == first_dims[0]; });
const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
[&](int64_t v) { return v == last_dims[0]; });
std::vector<int64_t> offsets(num_tensors, 0);
auto random_padding = [&]() -> int64_t {
// Random padding ensuring 16-byte alignment regardless of element size
// cuBLAS requires aligned pointers for vectorized loads
static std::mt19937 gen(12345);
std::uniform_int_distribution<int64_t> dist(0, 3);
// Number of elements per 16 bytes, rounded up, so padding by a multiple of it keeps offsets 16-byte aligned
const size_t align_elements =
std::max<size_t>(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size
return dist(gen) * static_cast<int64_t>(align_elements);
};
auto numel = [&](size_t idx) -> int64_t {
return first_dims[idx] * last_dims[idx];
};
const bool need_offsets = !same_first || !same_last;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
offsets[i] = static_cast<int64_t>(i) * numel(0);
}
}
grouped.offsets_host = offsets;
int64_t logical_first = 0;
int64_t logical_last = 0;
if (same_first && same_last) {
logical_first = first_dims[0] * static_cast<int64_t>(num_tensors);
logical_last = last_dims[0];
} else if (same_first && !same_last) {
logical_first = first_dims[0];
logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0});
} else if (!same_first && same_last) {
logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0});
logical_last = last_dims[0];
} else {
logical_first = 1;
logical_last = 0;
for (size_t i = 0; i < num_tensors; ++i) {
logical_last += first_dims[i] * last_dims[i];
}
}
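// Illustrative examples of the logical shapes above, using the grouped GEMM test shapes:
//   three 64x64 tensors (same first, same last) -> logical shape {192, 64}
//   {64x80, 64x96, 64x112} (same first dim)     -> logical shape {64, 288}
//   {64x80, 80x80, 96x80} (same last dim)       -> logical shape {240, 80}
//   all shapes different                        -> logical shape {1, total elements}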
size_t logical_data[2] = {static_cast<size_t>(logical_first),
static_cast<size_t>(logical_last)};
grouped.logical_shape = nvte_make_shape(logical_data, 2);
grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape));
const int64_t last_idx = static_cast<int64_t>(num_tensors - 1);
const int64_t total_elems = need_offsets
? (offsets[last_idx] + numel(last_idx))
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor);
const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.columnwise_data.get()) + offset_bytes,
tensors[i]->columnwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor col_tensor{grouped.columnwise_data.get(),
static_cast<NVTEDType>(dtype),
grouped.logical_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor);
}
if (!same_first) {
grouped.first_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor);
}
if (!same_last) {
grouped.last_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor);
}
if (!same_first || !same_last) {
grouped.offsets_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor);
}
if (isFp8Type(dtype)) {
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
}
return grouped;
}
} // namespace test
@@ -525,6 +525,60 @@ int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;
// Custom deleters for RAII
struct CudaDeleter {
void operator()(void* p) const { if (p) cudaFree(p); }
};
struct GroupedTensorDeleter {
void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); }
};
template <typename T = void>
using CudaPtr = std::unique_ptr<T, CudaDeleter>;
using GroupedTensorHandle = std::unique_ptr<std::remove_pointer_t<NVTEGroupedTensor>, GroupedTensorDeleter>;
// Helper to allocate CUDA memory into a CudaPtr
template <typename T = void>
CudaPtr<T> cuda_alloc(size_t bytes) {
void* ptr = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes));
return CudaPtr<T>(static_cast<T*>(ptr));
}
// Helper owning GPU buffers that back NVTEGroupedTensor.
// NVTEGroupedTensor does not own memory; data/offsets/scales
// must be allocated and freed by the test.
struct GroupedBuffers {
GroupedTensorHandle handle;
CudaPtr<> data;
CudaPtr<> scale_inv;
CudaPtr<int64_t> first_dims_dev;
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
CudaPtr<> columnwise_data;
NVTEShape logical_shape{};
std::vector<int64_t> offsets_host;
std::vector<size_t> tensor_bytes;
size_t num_tensors{0};
size_t elem_size{0};
DType dtype{DType::kFloat32};
NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING};
GroupedBuffers() = default;
GroupedBuffers(const GroupedBuffers&) = delete;
GroupedBuffers& operator=(const GroupedBuffers&) = delete;
GroupedBuffers(GroupedBuffers&&) = default;
GroupedBuffers& operator=(GroupedBuffers&&) = default;
~GroupedBuffers() = default;
// Convenience accessors for raw pointers
NVTEGroupedTensor get_handle() const { return handle.get(); }
void* get_data() const { return data.get(); }
};
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode);
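// Usage sketch (illustrative; names are hypothetical): the GroupedBuffers object
// must outlive every use of the handle, since the handle only borrows its memory.
//
//   std::vector<Tensor*> views = {&a0, &a1, &a2};
//   GroupedBuffers grouped = build_grouped_tensor(views, NVTE_DELAYED_TENSOR_SCALING);
//   some_nvte_call(grouped.get_handle(), ...);  // handle borrows grouped's buffers
//   // grouped's destructor frees the GPU buffers and destroys the handle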
} // namespace test
#if FP4_TYPE_SUPPORTED
......
@@ -202,6 +202,7 @@ if(USE_CUDA)
fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cublaslt_grouped_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
@@ -358,6 +359,7 @@ else()
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
gemm/cublaslt_grouped_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
......
@@ -126,3 +126,106 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
}
}
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() {
return new transformer_engine::GroupedMatmulConfig;
}
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
"Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for grouped matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::GroupedMatmulConfig *>(config);
switch (attr) {
case kNVTEGroupedMatmulConfigAvgM: {
int64_t val = config_.avg_m.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigAvgN: {
int64_t val = config_.avg_n.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigAvgK: {
int64_t val = config_.avg_k.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(buf, &config_.sm_count, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
"Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for grouped matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
switch (attr) {
case kNVTEGroupedMatmulConfigAvgM: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_m = val;
break;
}
case kNVTEGroupedMatmulConfigAvgN: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_n = val;
break;
}
case kNVTEGroupedMatmulConfigAvgK: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_k = val;
break;
}
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(&config_.sm_count, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
}
}
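// Usage sketch (illustrative, not part of this commit): setting an average-K hint
// for the cuBLASLt heuristics and passing the config to nvte_grouped_gemm.
//
//   NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();
//   int64_t avg_k = 64;
//   nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgK,
//                                            &avg_k, sizeof(avg_k));
//   // ... pass cfg as the config argument of nvte_grouped_gemm ...
//   nvte_destroy_grouped_matmul_config(cfg);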
@@ -9,6 +9,9 @@
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include <optional>
namespace transformer_engine {
struct MatmulConfig {
@@ -31,6 +34,22 @@ struct MatmulConfig {
};
};
struct GroupedMatmulConfig {
// Average dimension hints for cuBLASLt algorithm selection heuristics.
// nullopt means "not set" - compute automatically from tensor shapes.
std::optional<int64_t> avg_m;
std::optional<int64_t> avg_n;
std::optional<int64_t> avg_k;
// Number of streaming multiprocessors to use in GEMM kernel
int sm_count = 0;
// Note: the attribute API transfers the underlying value type, not the std::optional wrapper
static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
sizeof(decltype(avg_n)::value_type),
sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
};
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
@@ -311,13 +311,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
return ret;
}
/* cuBLAS version number at run-time */
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
} // namespace
#endif // __HIP_PLATFORM_AMD__
@@ -518,8 +511,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif // CUBLAS_VERSION >= 120800
} else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
NVTE_CHECK(cuda::cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
cuda::cublas_version());
// Check that scales are in expected format
NVTE_CHECK(inputA->with_gemm_swizzled_scales,
@@ -541,7 +535,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublas_version() <= 120803) {
if (cuda::cublas_version() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
@@ -553,8 +547,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif // CUBLAS_VERSION >= 120800
} else if (use_fp4) { // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
NVTE_CHECK(cuda::cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
cuda::cublas_version());
// Check that scales are in expected format
NVTE_CHECK(inputA->with_gemm_swizzled_scales,
@@ -589,9 +584,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
NVTE_CHECK(cublas_version() >= 120900,
NVTE_CHECK(cuda::cublas_version() >= 120900,
"FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
cublas_version());
cuda::cublas_version());
// Check that matrix formats are valid
NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
@@ -624,7 +619,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
if (cuda::cublas_version() >= 120800) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
&scaling_mode_a, sizeof(scaling_mode_a)));
@@ -641,7 +636,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
if (cuda::cublas_version() >= 120800) {
// NOTE: In all current cases where FP8 output is supported, the input is
// scaled identically to the output.
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -728,9 +723,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
NVTE_CHECK(cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
cuda::cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
@@ -1201,9 +1196,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
cuda::cublas_version());
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "./config.h"
namespace {
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}
} // namespace
#if CUBLAS_VERSION >= 130100
namespace {
// Helper struct to pass per-tensor shape/offset info (pointer or uniform value)
struct TensorShapeInfo {
const int64_t *first_dims; // nullptr if uniform
const int64_t *last_dims; // nullptr if uniform
const int64_t *offsets; // nullptr if need to compute
int64_t uniform_first; // used if first_dims == nullptr
int64_t uniform_last; // used if last_dims == nullptr
// Create from GroupedTensor
static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) {
const bool has_first = t->first_dims.has_data();
const bool has_last = t->last_dims.has_data();
// When per-tensor dims are not provided, we must be in the uniform-shape case.
NVTE_CHECK(has_first || t->all_same_first_dim(),
"GroupedTensor is missing first_dims for varying shapes");
NVTE_CHECK(has_last || t->all_same_last_dim(),
"GroupedTensor is missing last_dims for varying shapes");
const int64_t *first_ptr =
has_first ? static_cast<const int64_t *>(t->first_dims.dptr) : nullptr;
const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(t->last_dims.dptr) : nullptr;
const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(t->get_common_first_dim());
const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(t->get_common_last_dim());
return {first_ptr, last_ptr,
t->tensor_offsets.has_data() ? static_cast<const int64_t *>(t->tensor_offsets.dptr)
: nullptr,
uniform_first, uniform_last};
}
// Create for C tensor (uses D's dimensions, only has offsets)
static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C,
const transformer_engine::GroupedTensor *D) {
const bool has_first = D->first_dims.has_data();
const bool has_last = D->last_dims.has_data();
NVTE_CHECK(has_first || D->all_same_first_dim(),
"GroupedTensor D is missing first_dims for varying shapes");
NVTE_CHECK(has_last || D->all_same_last_dim(),
"GroupedTensor D is missing last_dims for varying shapes");
const int64_t *first_ptr =
has_first ? static_cast<const int64_t *>(D->first_dims.dptr) : nullptr;
const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(D->last_dims.dptr) : nullptr;
const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(D->get_common_first_dim());
const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(D->get_common_last_dim());
return {first_ptr, last_ptr,
C->tensor_offsets.has_data() ? static_cast<const int64_t *>(C->tensor_offsets.dptr)
: nullptr,
uniform_first, uniform_last};
}
};
// Helper functions to compute average dimensions from logical_shape for heuristics.
// These are hints for cuBLASLt algorithm selection and don't need to be exact.
inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) {
// logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first)
// In both cases, dividing by num_tensors gives the average
return static_cast<int64_t>(t->logical_shape.data[0]) / static_cast<int64_t>(t->num_tensors);
}
inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) {
if (t->all_same_last_dim()) {
// logical_shape[1] is the common N
return static_cast<int64_t>(t->logical_shape.data[1]);
}
// When varying, logical_shape[1] aggregates the last dims (or the total element count when all dims differ), so dividing by num_tensors yields an approximate average; good enough for a heuristic hint.
return static_cast<int64_t>(t->logical_shape.data[1]) / static_cast<int64_t>(t->num_tensors);
}
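// Illustrative example: three tensors with first dims {64, 80, 96} and a common
// last dim of 80 give logical_shape {240, 80}; avg first dim = 240 / 3 = 80 and
// avg last dim = 80 (taken directly from logical_shape[1]).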
// Workspace layout for grouped GEMM
struct GroupedGemmSetupWorkspace {
void **A_ptrs;
void **B_ptrs;
void **C_ptrs;
void **D_ptrs;
float **alpha_ptrs;
float **beta_ptrs;
// Storage dimensions for cuBLAS matrix layouts
int *a_rows;
int *a_cols;
int *b_rows;
int *b_cols;
int *d_rows; // M (first dim) - also used for C
int *d_cols; // N (last dim) - also used for C
// Initialize from workspace buffer
// Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned)
static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) {
GroupedGemmSetupWorkspace ws;
size_t offset = 0;
const size_t ptr_size = num_tensors * sizeof(void *);
const size_t int_size = num_tensors * sizeof(int);
// Pointer arrays first (all 8-byte aligned)
ws.A_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.B_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.C_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.D_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.alpha_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.beta_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
offset += ptr_size;
// Int arrays for storage dimensions (4-byte aligned)
ws.a_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.a_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.b_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.b_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.d_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.d_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
return ws;
}
// Calculate required size for setup workspace
static size_t required_setup_size(size_t num_tensors, size_t alignment) {
const size_t ptr_size = num_tensors * sizeof(void *);
const size_t int_size = num_tensors * sizeof(int);
// Layout: 6 ptr arrays, then 6 int arrays
size_t size = 6 * ptr_size + 6 * int_size;
size = ((size + alignment - 1) / alignment) * alignment;
return size;
}
};
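// Illustrative sizing: for num_tensors = 3 the setup workspace needs
// 6 * 3 * sizeof(void*) + 6 * 3 * sizeof(int) = 144 + 72 = 216 bytes,
// which rounds up to 256 bytes with the 256-byte alignment.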
// -----------------------------------------------------------------------------
// Helper routines to keep nvte_grouped_gemm readable
// -----------------------------------------------------------------------------
inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA,
const transformer_engine::GroupedTensor *inputB,
const transformer_engine::GroupedTensor *inputC,
const transformer_engine::GroupedTensor *outputD,
const transformer_engine::Tensor *alpha_tensor,
const transformer_engine::Tensor *beta_tensor) {
const size_t num_tensors = inputA->num_tensors;
NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1");
NVTE_CHECK(inputB->num_tensors == num_tensors,
"Grouped GEMM: A and B must have the same number of tensors");
// C can be NULL (will use D as C when beta=0)
if (inputC != nullptr) {
NVTE_CHECK(inputC->num_tensors == num_tensors,
"Grouped GEMM: A and C must have the same number of tensors");
}
NVTE_CHECK(outputD->num_tensors == num_tensors,
"Grouped GEMM: A and D must have the same number of tensors");
// Validate alpha/beta have per-matrix values
const size_t alpha_numel = alpha_tensor->data.numel();
const size_t beta_numel = beta_tensor->data.numel();
NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors,
") elements, got ", alpha_numel);
NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors,
") elements, got ", beta_numel);
auto is_fp8_or_16bit = [](transformer_engine::DType dtype) {
return dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2 ||
dtype == transformer_engine::DType::kBFloat16 ||
dtype == transformer_engine::DType::kFloat16;
};
auto is_output_dtype = [](transformer_engine::DType dtype) {
return dtype == transformer_engine::DType::kBFloat16 ||
dtype == transformer_engine::DType::kFloat16 ||
dtype == transformer_engine::DType::kFloat32;
};
NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()),
"Grouped GEMM inputs must be FP8, BF16, or FP16.");
// Only check C dtype if C is provided
if (inputC != nullptr) {
NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32.");
}
NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32.");
NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(),
"Grouped GEMM: A tensor is missing both row-wise and column-wise data");
NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(),
"Grouped GEMM: B tensor is missing both row-wise and column-wise data");
}
// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM.
// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and
// fallback to column-wise data when row-wise is absent.
// Contains all information needed for GEMM setup - shape already accounts for storage layout.
struct GroupedOperandSelection {
TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed
char *dptr = nullptr;
void *scale_inv = nullptr;
transformer_engine::DType dtype = transformer_engine::DType::kNumTypes;
bool trans = false;
};
// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims.
// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage.
// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor).
inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t,
bool swap_dims) {
const bool has_first = t->first_dims.has_data();
const bool has_last = t->last_dims.has_data();
NVTE_CHECK(has_first || t->all_same_first_dim(),
"GroupedTensor is missing first_dims for varying shapes");
NVTE_CHECK(has_last || t->all_same_last_dim(),
"GroupedTensor is missing last_dims for varying shapes");
const int64_t *first_ptr = has_first ? static_cast<const int64_t *>(t->first_dims.dptr) : nullptr;
const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(t->last_dims.dptr) : nullptr;
const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(t->get_common_first_dim());
const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(t->get_common_last_dim());
const int64_t *offsets_ptr =
t->tensor_offsets.has_data() ? static_cast<const int64_t *>(t->tensor_offsets.dptr) : nullptr;
if (swap_dims) {
// Swap first/last to account for columnwise (transposed) storage
return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first};
}
return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last};
}
inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t,
bool trans, bool is_A) {
using namespace transformer_engine;
const bool has_row = t->has_data();
const bool has_col = t->has_columnwise_data();
NVTE_CHECK(has_row || has_col,
"Grouped GEMM operand is missing both row-wise and column-wise data");
// Currently only unquantized data and tensor-scaled FP8 are supported.
const auto sm = t->scaling_mode;
NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING,
"Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data");
const DType row_dtype = t->data.dtype;
const DType col_dtype = t->columnwise_data.dtype;
GroupedOperandSelection sel;
sel.trans = trans;
const DType rep_dtype = has_row ? row_dtype : col_dtype;
const bool is_fp8 = is_fp8_dtype(rep_dtype);
const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();
// Helper to select columnwise storage (swaps dims in shape)
auto use_columnwise = [&]() {
sel.dptr = static_cast<char *>(t->columnwise_data.dptr);
sel.scale_inv = t->columnwise_scale_inv.dptr;
sel.dtype = col_dtype;
sel.shape = create_shape_info(t, /*swap_dims=*/true);
};
// Helper to select row-wise storage
auto use_rowwise = [&]() {
sel.dptr = static_cast<char *>(t->data.dptr);
sel.scale_inv = t->scale_inv.dptr;
sel.dtype = row_dtype;
sel.shape = create_shape_info(t, /*swap_dims=*/false);
};
// Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed.
if (is_fp8 && !non_tn_fp8_ok) {
if (is_A) {
if (!sel.trans) {
NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout");
use_columnwise();
sel.trans = true; // using pre-transposed storage
return sel;
}
} else { // B
if (sel.trans) {
NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout");
use_columnwise();
sel.trans = false; // using pre-transposed storage
return sel;
}
}
}
// If only column-wise data is available, use it and flip the transpose flag (pre-transposed storage).
if (!has_row && has_col) {
// On Hopper FP8, this would break TN requirement - should have been handled above
NVTE_CHECK(
!is_fp8 || non_tn_fp8_ok,
"Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration");
use_columnwise();
sel.trans = !trans; // flip transpose for pre-transposed storage
return sel;
}
// Default: use row-wise data
use_rowwise();
return sel;
}
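// Illustrative summary of the FP8 selection above when non-TN FP8 GEMM is
// unsupported (Hopper-style), assuming row-wise data is present:
//   A, trans=false -> column-wise storage, trans=true   (forces TN)
//   A, trans=true  -> row-wise storage,    trans=true
//   B, trans=true  -> column-wise storage, trans=false  (forces TN)
//   B, trans=false -> row-wise storage,    trans=false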
inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size,
const char *workspace_name) {
NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null.");
const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype);
NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name,
". Required: ", required_size, " bytes, Available: ", provided_size, " bytes.");
return ws->data.dptr;
}
inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA,
cublasLtMatrixLayoutOpaque_t &descB,
cublasLtMatrixLayoutOpaque_t &descC,
cublasLtMatrixLayoutOpaque_t &descD,
const GroupedGemmSetupWorkspace &ws,
const GroupedOperandSelection &A_sel,
const GroupedOperandSelection &B_sel,
const transformer_engine::GroupedTensor *D, size_t num_tensors) {
const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype);
const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype);
const cudaDataType_t D_type = get_cuda_dtype(D->dtype());
// Storage dimensions computed by kernel, leading dimension = rows
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows,
ws.a_cols, ws.a_rows));
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows,
ws.b_cols, ws.b_rows));
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows,
ws.d_cols, ws.d_rows));
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows,
ws.d_cols, ws.d_rows));
}
inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A,
cublasOperation_t op_B) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A,
sizeof(op_A)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B,
sizeof(op_B)));
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode, sizeof(pointer_mode)));
int64_t alphabeta_batch_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE,
&alphabeta_batch_stride, sizeof(int64_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE,
&alphabeta_batch_stride, sizeof(int64_t)));
}
inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc,
const GroupedOperandSelection &A_sel,
const GroupedOperandSelection &B_sel) {
const bool is_fp8_a = is_fp8_dtype(A_sel.dtype);
const bool is_fp8_b = is_fp8_dtype(B_sel.dtype);
if (!is_fp8_a && !is_fp8_b) return;
if (is_fp8_a) {
void *a_scale_inv = A_sel.scale_inv;
NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required");
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv)));
}
if (is_fp8_b) {
void *b_scale_inv = B_sel.scale_inv;
NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required");
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv)));
}
}
// Constants for grouped GEMM workspace (declared early for use in heuristics)
static constexpr size_t kGroupedGemmAlignment = 256;
static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB
inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle,
cublasLtMatmulDescOpaque_t &matmulDesc,
cublasLtMatrixLayoutOpaque_t &descA,
cublasLtMatrixLayoutOpaque_t &descB,
cublasLtMatrixLayoutOpaque_t &descC,
cublasLtMatrixLayoutOpaque_t &descD,
int64_t avg_m, int64_t avg_n, int64_t avg_k) {
cublasLtMatmulPreferenceOpaque_t preference;
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference));
NVTE_CHECK_CUBLAS(
cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&kGroupedGemmCublasWorkspaceSize, sizeof(size_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t)));
cublasLtMatmulHeuristicResult_t heuristicResult;
int returnedResults = 0;
auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD,
&preference, 1, &heuristicResult, &returnedResults);
NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
"Unable to find suitable cuBLAS grouped GEMM algorithm");
NVTE_CHECK_CUBLAS(status);
NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM");
return heuristicResult.algo;
}
// Single kernel that sets up all GEMM parameters.
// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel.
__global__ void setup_grouped_gemm_kernel(
// Output arrays
void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols,
int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs,
// Inputs
char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta,
TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size,
size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr,
size_t num_tensors) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_tensors) return;
// Get dimensions for this tensor (from array or uniform value)
int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first;
int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last;
int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first;
int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last;
int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first;
int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last;
// Compute offsets (from array or compute from uniform dims)
int64_t a_offset =
A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last);
int64_t b_offset =
B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last);
int64_t c_offset =
C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last);
int64_t d_offset =
D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);
// Compute data pointers
A_ptrs[idx] = a_base + a_offset * a_elem_size;
B_ptrs[idx] = b_base + b_offset * b_elem_size;
C_ptrs[idx] = c_base + c_offset * c_elem_size;
D_ptrs[idx] = d_base + d_offset * d_elem_size;
// Compute storage dimensions for cuBLAS matrix layouts.
// For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS,
// so rows=last, cols=first. For columnwise, dims are already swapped.
a_rows[idx] = static_cast<int>(a_last);
a_cols[idx] = static_cast<int>(a_first);
b_rows[idx] = static_cast<int>(b_last);
b_cols[idx] = static_cast<int>(b_first);
// For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
d_rows[idx] = static_cast<int>(d_first);
d_cols[idx] = static_cast<int>(d_last);
// Fill alpha/beta pointers (per-matrix)
alpha_ptrs[idx] = alpha_ptr + idx;
beta_ptrs[idx] = beta_ptr + idx;
}
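// Illustrative pointer math: with uniform 64x96 tensors and no explicit offsets,
// thread idx = 2 computes a_offset = 2 * 64 * 96 = 12288 elements, so
// A_ptrs[2] = a_base + 12288 * a_elem_size.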
// Launch the setup kernel to populate workspace arrays
inline void launch_grouped_gemm_setup(
const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel,
const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C,
const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor,
const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) {
// Use shape info from selection (already accounts for columnwise dimension swap)
TensorShapeInfo A_meta = A_sel.shape;
TensorShapeInfo B_meta = B_sel.shape;
TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D);
TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D);
char *c_base = static_cast<char *>(C->data.dptr);
char *d_base = static_cast<char *>(D->data.dptr);
const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype);
const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype);
const size_t c_elem_size = transformer_engine::typeToSize(C->dtype());
const size_t d_elem_size = transformer_engine::typeToSize(D->dtype());
const int threads_per_block = 256;
const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block;
setup_grouped_gemm_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols,
ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base,
A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size,
static_cast<float *>(alpha_tensor->data.dptr), static_cast<float *>(beta_tensor->data.dptr),
num_tensors);
NVTE_CHECK_CUDA(cudaGetLastError());
}
inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {
return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment);
}
} // namespace
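// Sketch (not part of this commit): the setup workspace packs, per group, four
// device-pointer arrays, six int dimension arrays, and two float* scale-pointer
// arrays, matching the members populated by setup_grouped_gemm_kernel above.
// kAlign stands in for kGroupedGemmAlignment (its value is not shown in this
// diff), and the real GroupedGemmSetupWorkspace may order or pad members
// differently.
inline size_t sketch_grouped_setup_size(size_t num_tensors, size_t kAlign) {
  auto aligned = [&](size_t bytes) { return (bytes + kAlign - 1) / kAlign * kAlign; };
  size_t total = 0;
  total += 4 * aligned(num_tensors * sizeof(void *));   // A_ptrs, B_ptrs, C_ptrs, D_ptrs
  total += 6 * aligned(num_tensors * sizeof(int));      // a/b/d rows and cols
  total += 2 * aligned(num_tensors * sizeof(float *));  // alpha_ptrs, beta_ptrs
  return total;
}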
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_grouped_gemm);
using namespace transformer_engine;
// Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.1+
const int current_device = cuda::current_device();
NVTE_CHECK(cuda::sm_arch(current_device) >= 100,
"nvte_grouped_gemm requires Blackwell (SM100) or newer architecture.");
NVTE_CHECK(cuda::cublas_version() >= 130100,
"nvte_grouped_gemm requires cuBLAS 13.1+, but run-time cuBLAS version is ",
cuda::cublas_version());
// Convert to internal types
const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A);
const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B);
const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL
GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D);
const Tensor *alpha_tensor = convertNVTETensorCheck(alpha);
const Tensor *beta_tensor = convertNVTETensorCheck(beta);
Tensor *wspace_setup = convertNVTETensor(workspace_setup);
Tensor *wspace_cublas = convertNVTETensor(workspace_cublas);
// Parse config (if provided)
GroupedMatmulConfig config_;
if (config != nullptr) {
config_ = *reinterpret_cast<GroupedMatmulConfig *>(config);
}
// Validate inputs and num_tensors
validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor);
// If C is NULL, use D as C (valid when beta=0, since cuBLAS will not read C's data)
const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD;
const size_t num_tensors = inputA->num_tensors;
// Select operand storage (row-wise vs column-wise) and adjust transpose flags to
// mirror the non-grouped GEMM logic for FP8 layout constraints.
const auto A_sel = select_grouped_operand(inputA, static_cast<bool>(transa), /*is_A=*/true);
const auto B_sel = select_grouped_operand(inputB, static_cast<bool>(transb), /*is_A=*/false);
// Workspaces: setup (pointer arrays) and cuBLAS
const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors);
const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize;
void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size,
"Grouped GEMM setup workspace");
void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size,
"Grouped GEMM cuBLAS workspace");
auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers(
static_cast<char *>(setup_workspace_ptr), num_tensors);
launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor,
beta_tensor, num_tensors, stream);
// Get cuBLAS handle
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
// Setup cuBLAS operations
cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N;
// Create grouped matrix layouts
cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD;
init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD,
num_tensors);
// Create matmul descriptor
cublasLtMatmulDescOpaque_t matmulDesc;
init_matmul_desc(matmulDesc, op_A, op_B);
set_fp8_scale_pointers(matmulDesc, A_sel, B_sel);
// Compute average dimensions for heuristics
// K dimension: if transa, K is A's first dim; if not, K is A's last dim
// Use original inputA and transa for heuristics (not modified A_sel.trans)
int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD));
int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD));
int64_t avg_k_val =
config_.avg_k.value_or(transa ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA));
// Heuristic selection
cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC,
descD, avg_m_val, avg_n_val, avg_k_val);
// Execute the grouped GEMM
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs,
setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB,
setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC,
setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr,
kGroupedGemmCublasWorkspaceSize, stream));
}
#else // CUBLAS_VERSION < 130100
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream) {
NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ",
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
}
#endif // CUBLAS_VERSION >= 130100
@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
@@ -20,6 +22,9 @@ extern "C" {
/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;
/*! \brief Configuration for grouped matrix multiplication. */
typedef void *NVTEGroupedMatmulConfig;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
@@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute {
kNVTEMatmulConfigNumAttributes
};
/*! \enum NVTEGroupedMatmulConfigAttribute
* \brief Type of option for grouped matrix multiplication.
*/
enum NVTEGroupedMatmulConfigAttribute {
/*! Average M dimension hint
*
* Optional hint for average M dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgM = 0,
/*! Average N dimension hint
*
* Optional hint for average N dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgN = 1,
/*! Average K (reduction) dimension hint
*
* Optional hint for average K dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from A's logical shape.
*/
kNVTEGroupedMatmulConfigAvgK = 2,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEGroupedMatmulConfigSMCount = 3,
kNVTEGroupedMatmulConfigNumAttributes
};
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
@@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
/*! \brief Create a grouped matrix multiplication configuration. */
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config();
/*! \brief Query an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
* \param[in] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a grouped matrix multiplication configuration. */
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config);
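/* Usage sketch (illustration only, not part of this header): create a config,
 * set the average-M hint, read it back, and destroy it. The int64_t attribute
 * width matches the setters in GroupedMatmulConfigWrapper below. */
static inline void example_grouped_matmul_config(void) {
  NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();
  int64_t avg_m = 512; /* heuristic hint for cuBLASLt algorithm selection */
  nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &avg_m,
                                           sizeof(avg_m));
  int64_t readback = 0;
  size_t written = 0;
  nvte_get_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &readback,
                                           sizeof(readback), &written);
  nvte_destroy_grouped_matmul_config(cfg);
}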
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
@@ -229,6 +296,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
*
* \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
* Will error at runtime if compiled with an older cuBLAS version or run on
* a pre-Blackwell GPU.
*
* Performs batched GEMM on a collection of matrices with potentially different shapes.
* All tensors in the group must have compatible dimensions for matrix multiplication.
* Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous
* memory layout and shape metadata.
*
* \param[in] A Input grouped tensor A.
* \param[in] transa Whether to transpose A matrices.
* \param[in] B Input grouped tensor B.
* \param[in] transb Whether to transpose B matrices.
* \param[in] C Input grouped tensor C (can be NULL for beta=0).
* \param[out] D Output grouped tensor D.
* \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements).
* \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements).
* \param[in] workspace_setup Workspace tensor for pointer array setup.
* \param[in] workspace_cublas Workspace tensor for cuBLAS operations.
* \param[in] config Additional configuration (can be NULL for defaults).
* \param[in] stream CUDA stream for the operation.
*
* Requirements:
* - cuBLAS 13.1+ (CUDA 13.1+)
* - Blackwell (SM100) or newer GPU architecture
* - A, B, C (if provided), D must have the same num_tensors
* - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
* - Shape compatibility: if transa=false, transb=false:
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
*/
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream);
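/* Call sketch (illustration only): A, B, D, the per-matrix alpha/beta tensors,
 * and the two workspace tensors are assumed to be prepared elsewhere; their
 * construction APIs are not part of this diff. C is passed as NULL, which is
 * valid only when every beta[i] is 0 (cuBLAS then never reads C). */
static inline void example_grouped_gemm(NVTEGroupedTensor A, NVTEGroupedTensor B,
                                        NVTEGroupedTensor D, NVTETensor alpha, NVTETensor beta,
                                        NVTETensor ws_setup, NVTETensor ws_cublas,
                                        cudaStream_t stream) {
  nvte_grouped_gemm(A, /*transa=*/0, B, /*transb=*/0, /*C=*/NULL, D, alpha, beta, ws_setup,
                    ws_cublas, /*config=*/NULL, stream);
}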
#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out,
@@ -356,6 +463,70 @@ class MatmulConfigWrapper {
NVTEMatmulConfig config_ = nullptr;
};
/*! \struct GroupedMatmulConfigWrapper
* \brief C++ wrapper for NVTEGroupedMatmulConfig.
*/
class GroupedMatmulConfigWrapper {
public:
GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {}
GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~GroupedMatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEGroupedMatmulConfig.
*
* \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.
*/
operator NVTEGroupedMatmulConfig() const noexcept { return config_; }
/*! \brief Set average M dimension hint for algorithm selection. */
void set_avg_m(int64_t avg_m) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m,
sizeof(int64_t));
}
/*! \brief Set average N dimension hint for algorithm selection. */
void set_avg_n(int64_t avg_n) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n,
sizeof(int64_t));
}
/*! \brief Set average K dimension hint for algorithm selection. */
void set_avg_k(int64_t avg_k) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k,
sizeof(int64_t));
}
/*! \brief Set number of streaming multiprocessors to use. */
void set_sm_count(int sm_count) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count,
sizeof(int));
}
private:
/*! \brief Wrapped NVTEGroupedMatmulConfig. */
NVTEGroupedMatmulConfig config_ = nullptr;
};
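/*! Usage sketch (illustration only, not part of this header): the wrapper
 * converts implicitly to NVTEGroupedMatmulConfig, so it can be passed directly
 * to nvte_grouped_gemm; the grouped tensors, scale tensors, workspaces, and
 * stream are assumed to exist. The hint values are arbitrary. */
inline void example_grouped_gemm_with_hints(NVTEGroupedTensor A, NVTEGroupedTensor B,
                                            NVTEGroupedTensor C, NVTEGroupedTensor D,
                                            NVTETensor alpha, NVTETensor beta,
                                            NVTETensor ws_setup, NVTETensor ws_cublas,
                                            cudaStream_t stream) {
  GroupedMatmulConfigWrapper cfg;
  cfg.set_avg_m(1024);  // optional heuristic hints; omit to let TE compute them
  cfg.set_avg_k(4096);
  nvte_grouped_gemm(A, /*transa=*/0, B, /*transb=*/1, C, D, alpha, beta, ws_setup, ws_cublas,
                    cfg, stream);
}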
} // namespace transformer_engine
#endif // __cplusplus
@@ -6,6 +6,8 @@
#include "../util/cuda_runtime.h"
#include <cublasLt.h>
#include <filesystem>
#include <mutex>
@@ -232,6 +234,12 @@ int cudart_version() {
return version;
}
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
} // namespace cuda
} // namespace transformer_engine
@@ -85,6 +85,12 @@ const std::string &include_directory(bool required = false);
*/
int cudart_version();
/* \brief cuBLAS version number at run-time
*
* Versions may differ between compile-time and run-time.
*/
size_t cublas_version();
} // namespace cuda
} // namespace transformer_engine