Unverified commit b9f40131, authored by Paweł Gadziński and committed by GitHub

[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add FP8 scale support and fix alignment for grouped GEMM

- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Grouped GEMM: code cleanup and NULL C support

- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Grouped GEMM: per-matrix alpha/beta support

- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fix alpha/beta numel - use SimpleTensor::numel()
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Refactor: move grouped GEMM to separate file and cleanup API
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/common/gemm/config.h
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* changed
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* suggestions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* refactored hopper tensor selection
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent f04b094c
......@@ -30,6 +30,7 @@ add_executable(test_operator
test_causal_softmax.cu
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu
../test_common.cu)
# Find required packages
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum class InputCase {
kFP8Current,
kBF16,
};
enum class ShapeCase {
kAllSame,
kSameFirst,
kSameLast,
kAllDifferent,
};
size_t grouped_setup_workspace_size(const size_t num_tensors) {
const size_t ptr_bytes = num_tensors * sizeof(void*);
const size_t int_bytes = num_tensors * sizeof(int);
// Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
size_t size = 6 * ptr_bytes + 6 * int_bytes;
const size_t alignment = 256;
size = ((size + alignment - 1) / alignment) * alignment;
return size;
}
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
fillUniform(&input_fp32);
Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(fp8.data(), config, 0);
nvte_quantize(input_fp32.data(), fp8.data(), 0);
return fp8;
}
Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor t(name, shape, DType::kBFloat16);
const size_t numel = shape[0] * shape[1];
// Fill with ones so every output element equals K, making results easy to verify.
std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
return t;
}
struct TestParams {
InputCase input_case;
bool transa;
bool transb;
ShapeCase shape_case;
bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0)
};
// Returns a vector of (M, N, K) tuples for each GEMM in the group.
// M - number of rows in output D
// N - number of columns in output D
// K - reduction dimension shared between A and B
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
switch (scase) {
case ShapeCase::kAllSame:
return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
case ShapeCase::kSameFirst:
// Same M (first dim), varying N and K
return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
case ShapeCase::kSameLast:
// Same N (last dim), varying M and K
return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
case ShapeCase::kAllDifferent:
default:
return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
}
}
void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130100
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);
const size_t num_gemms = shapes.size();
std::vector<Tensor> A_tensors;
std::vector<Tensor> B_tensors;
std::vector<Tensor> D_multi;
A_tensors.reserve(num_gemms);
B_tensors.reserve(num_gemms);
D_multi.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
: std::vector<size_t>{K, M};
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
: std::vector<size_t>{N, K};
switch (params.input_case) {
case InputCase::kFP8Current: {
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape));
break;
}
case InputCase::kBF16: {
A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
break;
}
}
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N},
DType::kBFloat16));
}
std::vector<NVTETensor> A_ptrs(num_gemms);
std::vector<NVTETensor> B_ptrs(num_gemms);
std::vector<NVTETensor> D_ptrs(num_gemms);
std::vector<Tensor> workspaces(num_gemms);
std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
std::vector<Tensor*> A_views;
std::vector<Tensor*> B_views;
A_views.reserve(num_gemms);
B_views.reserve(num_gemms);
// Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);
const size_t cublas_ws_bytes = 32ull * 1024 * 1024;
for (size_t i = 0; i < num_gemms; ++i) {
A_ptrs[i] = A_tensors[i].data();
B_ptrs[i] = B_tensors[i].data();
D_ptrs[i] = D_multi[i].data();
workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
workspace_ptrs[i] = workspaces[i].data();
A_views.push_back(&A_tensors[i]);
B_views.push_back(&B_tensors[i]);
}
nvte_multi_tensor_gemm(A_ptrs.data(),
B_ptrs.data(),
D_ptrs.data(),
bias_ptrs.data(),
gelu_ptrs.data(),
static_cast<int>(num_gemms),
params.transa,
params.transb,
false, // grad
workspace_ptrs.data(),
false, // accumulate
false, // use_split_accumulator
0, // sm_count
0);
GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode());
GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode());
std::vector<Tensor> C_tensors;
std::vector<Tensor> D_group_tensors;
C_tensors.reserve(num_gemms);
D_group_tensors.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
(void)K;
if (!params.use_null_c) {
C_tensors.emplace_back(Tensor("C" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
}
D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype())));
}
std::vector<Tensor*> C_views, D_views;
for (size_t i = 0; i < num_gemms; ++i) {
if (!params.use_null_c) {
C_views.push_back(&C_tensors[i]);
}
D_views.push_back(&D_group_tensors[i]);
}
std::optional<GroupedBuffers> grouped_C;
if (!params.use_null_c) {
grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
}
GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);
// Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
Tensor alpha_tensor("alpha", std::vector<size_t>{num_gemms}, DType::kFloat32);
Tensor beta_tensor("beta", std::vector<size_t>{num_gemms}, DType::kFloat32);
std::vector<float> alpha_vals(num_gemms, 1.f);
std::vector<float> beta_vals(num_gemms, 0.f);
NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
nvte_grouped_gemm(grouped_A.get_handle(),
params.transa,
grouped_B.get_handle(),
params.transb,
params.use_null_c ? nullptr : grouped_C->get_handle(),
grouped_D.get_handle(),
alpha_tensor.data(),
beta_tensor.data(),
setup_ws.data(),
cublas_ws.data(),
nullptr, // config (use defaults)
0);
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
static_cast<size_t>(std::get<1>(shapes[i]))},
D_multi[i].dtype());
const size_t offset_bytes = static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
static_cast<char*>(grouped_D.get_data()) + offset_bytes,
grouped_D.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
grouped_split.to_cpu();
D_multi[i].to_cpu();
auto [atol, rtol] = getTolerances(D_multi[i].dtype());
compareResults("grouped_vs_multi",
grouped_split,
D_multi[i].rowwise_cpu_dptr<bf16>(),
true,
atol,
rtol);
}
#endif // CUBLAS_VERSION >= 130100
}
class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
run_grouped_gemm_case(GetParam());
}
std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
"tb" + (info.param.transb ? "T" : "N");
const std::string null_c = info.param.use_null_c ? "_NullC" : "";
return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
}
// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// Basic tests
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
};
INSTANTIATE_TEST_SUITE_P(OperatorTest,
GroupedGemmTest,
::testing::ValuesIn(kTestParams),
MakeGroupedGemmTestName);
} // namespace
......@@ -9,6 +9,7 @@
#include <algorithm>
#include <memory>
#include <numeric>
#include <random>
#include <iostream>
#include <cassert>
......@@ -1057,4 +1058,166 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode) {
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
const NVTEShape shape = tensors[0]->rowwise_shape();
const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
GroupedBuffers grouped;
grouped.elem_size = elem_size;
grouped.num_tensors = num_tensors;
grouped.dtype = dtype;
grouped.scaling_mode = scaling_mode;
grouped.tensor_bytes.resize(num_tensors);
grouped.offsets_host.resize(num_tensors, 0);
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = tensors[i]->rowwise_shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
grouped.tensor_bytes[i] = bytes(s, dtype);
}
const bool same_first = std::all_of(first_dims.begin(), first_dims.end(),
[&](int64_t v) { return v == first_dims[0]; });
const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
[&](int64_t v) { return v == last_dims[0]; });
std::vector<int64_t> offsets(num_tensors, 0);
auto random_padding = [&]() -> int64_t {
// Random padding ensuring 16-byte alignment regardless of element size
// cuBLAS requires aligned pointers for vectorized loads
static std::mt19937 gen(12345);
std::uniform_int_distribution<int64_t> dist(0, 3);
// Number of elements needed to cover 16 bytes, rounded up
const size_t align_elements =
std::max<size_t>(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size
return dist(gen) * static_cast<int64_t>(align_elements);
};
auto numel = [&](size_t idx) -> int64_t {
return first_dims[idx] * last_dims[idx];
};
const bool need_offsets = !same_first || !same_last;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
offsets[i] = static_cast<int64_t>(i) * numel(0);
}
}
grouped.offsets_host = offsets;
int64_t logical_first = 0;
int64_t logical_last = 0;
if (same_first && same_last) {
logical_first = first_dims[0] * static_cast<int64_t>(num_tensors);
logical_last = last_dims[0];
} else if (same_first && !same_last) {
logical_first = first_dims[0];
logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0});
} else if (!same_first && same_last) {
logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0});
logical_last = last_dims[0];
} else {
logical_first = 1;
logical_last = 0;
for (size_t i = 0; i < num_tensors; ++i) {
logical_last += first_dims[i] * last_dims[i];
}
}
size_t logical_data[2] = {static_cast<size_t>(logical_first),
static_cast<size_t>(logical_last)};
grouped.logical_shape = nvte_make_shape(logical_data, 2);
grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape));
const int64_t last_idx = static_cast<int64_t>(num_tensors - 1);
const int64_t total_elems = need_offsets
? (offsets[last_idx] + numel(last_idx))
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor);
const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.columnwise_data.get()) + offset_bytes,
tensors[i]->columnwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor col_tensor{grouped.columnwise_data.get(),
static_cast<NVTEDType>(dtype),
grouped.logical_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor);
}
if (!same_first) {
grouped.first_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor);
}
if (!same_last) {
grouped.last_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor);
}
if (!same_first || !same_last) {
grouped.offsets_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor);
}
if (isFp8Type(dtype)) {
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
}
return grouped;
}
} // namespace test
......@@ -504,6 +504,60 @@ int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;
// Custom deleters for RAII
struct CudaDeleter {
void operator()(void* p) const { if (p) cudaFree(p); }
};
struct GroupedTensorDeleter {
void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); }
};
template <typename T = void>
using CudaPtr = std::unique_ptr<T, CudaDeleter>;
using GroupedTensorHandle = std::unique_ptr<std::remove_pointer_t<NVTEGroupedTensor>, GroupedTensorDeleter>;
// Helper to allocate CUDA memory into a CudaPtr
template <typename T = void>
CudaPtr<T> cuda_alloc(size_t bytes) {
void* ptr = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes));
return CudaPtr<T>(static_cast<T*>(ptr));
}
// Helper owning GPU buffers that back NVTEGroupedTensor.
// NVTEGroupedTensor does not own memory; data/offsets/scales
// must be allocated and freed by the test.
struct GroupedBuffers {
GroupedTensorHandle handle;
CudaPtr<> data;
CudaPtr<> scale_inv;
CudaPtr<int64_t> first_dims_dev;
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
CudaPtr<> columnwise_data;
NVTEShape logical_shape{};
std::vector<int64_t> offsets_host;
std::vector<size_t> tensor_bytes;
size_t num_tensors{0};
size_t elem_size{0};
DType dtype{DType::kFloat32};
NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING};
GroupedBuffers() = default;
GroupedBuffers(const GroupedBuffers&) = delete;
GroupedBuffers& operator=(const GroupedBuffers&) = delete;
GroupedBuffers(GroupedBuffers&&) = default;
GroupedBuffers& operator=(GroupedBuffers&&) = default;
~GroupedBuffers() = default;
// Convenience accessors for raw pointers
NVTEGroupedTensor get_handle() const { return handle.get(); }
void* get_data() const { return data.get(); }
};
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode);
} // namespace test
#if FP4_TYPE_SUPPORTED
......
......@@ -144,6 +144,7 @@ list(APPEND transformer_engine_cuda_sources
fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cublaslt_grouped_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
......
......@@ -126,3 +126,106 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
}
}
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() {
return new transformer_engine::GroupedMatmulConfig;
}
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
"Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for grouped matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::GroupedMatmulConfig *>(config);
switch (attr) {
case kNVTEGroupedMatmulConfigAvgM: {
int64_t val = config_.avg_m.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigAvgN: {
int64_t val = config_.avg_n.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigAvgK: {
int64_t val = config_.avg_k.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(buf, &config_.sm_count, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
"Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for grouped matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
switch (attr) {
case kNVTEGroupedMatmulConfigAvgM: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_m = val;
break;
}
case kNVTEGroupedMatmulConfigAvgN: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_n = val;
break;
}
case kNVTEGroupedMatmulConfigAvgK: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_k = val;
break;
}
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(&config_.sm_count, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
}
}
......@@ -9,6 +9,9 @@
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include <optional>
namespace transformer_engine {
struct MatmulConfig {
......@@ -31,6 +34,22 @@ struct MatmulConfig {
};
};
struct GroupedMatmulConfig {
// Average dimension hints for cuBLASLt algorithm selection heuristics.
// nullopt means "not set" - compute automatically from tensor shapes.
std::optional<int64_t> avg_m;
std::optional<int64_t> avg_n;
std::optional<int64_t> avg_k;
// Number of streaming multiprocessors to use in GEMM kernel
int sm_count = 0;
// Note: the C API passes the underlying value type across the boundary, not std::optional
static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
sizeof(decltype(avg_n)::value_type),
sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
};
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
......@@ -302,13 +302,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
return ret;
}
/* cuBLAS version number at run-time */
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
} // namespace
namespace transformer_engine {
......@@ -501,8 +494,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif // CUBLAS_VERSION >= 120800
} else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
NVTE_CHECK(cuda::cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
cuda::cublas_version());
// Check that scales are in expected format
NVTE_CHECK(inputA->with_gemm_swizzled_scales,
......@@ -524,7 +518,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublas_version() <= 120803) {
if (cuda::cublas_version() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
......@@ -536,8 +530,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif // CUBLAS_VERSION >= 120800
} else if (use_fp4) { // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
NVTE_CHECK(cuda::cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
cuda::cublas_version());
// Check that scales are in expected format
NVTE_CHECK(inputA->with_gemm_swizzled_scales,
......@@ -572,9 +567,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
NVTE_CHECK(cublas_version() >= 120900,
NVTE_CHECK(cuda::cublas_version() >= 120900,
"FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
cublas_version());
cuda::cublas_version());
// Check that matrix formats are valid
NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
......@@ -607,7 +602,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
if (cuda::cublas_version() >= 120800) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
&scaling_mode_a, sizeof(scaling_mode_a)));
@@ -624,7 +619,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
if (cuda::cublas_version() >= 120800) {
// NOTE: In all current cases where FP8 output is supported, the input is
// scaled identically to the output.
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -711,9 +706,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
NVTE_CHECK(cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
cuda::cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
@@ -939,9 +934,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
cuda::cublas_version());
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
This diff is collapsed.
@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
@@ -20,6 +22,9 @@ extern "C" {
/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;
/*! \brief Configuration for grouped matrix multiplication. */
typedef void *NVTEGroupedMatmulConfig;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
@@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute {
kNVTEMatmulConfigNumAttributes
};
/*! \enum NVTEGroupedMatmulConfigAttribute
* \brief Type of option for grouped matrix multiplication.
*/
enum NVTEGroupedMatmulConfigAttribute {
/*! Average M dimension hint
*
* Optional hint for average M dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgM = 0,
/*! Average N dimension hint
*
* Optional hint for average N dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgN = 1,
/*! Average K (reduction) dimension hint
*
* Optional hint for average K dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from A's logical shape.
*/
kNVTEGroupedMatmulConfigAvgK = 2,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEGroupedMatmulConfigSMCount = 3,
kNVTEGroupedMatmulConfigNumAttributes
};
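The AvgM/AvgN/AvgK attributes are heuristic hints; when unset, the averages are derived from the logical shapes of D and A as documented above. A minimal sketch of that averaging (illustrative only, not the TE implementation; `avg_dim` is a hypothetical helper):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: average one dimension across all matrices in the group,
// as the AvgM/AvgN/AvgK hints describe. The real defaults are computed from
// the grouped tensors' logical shapes.
int64_t avg_dim(const std::vector<int64_t> &dims) {
  assert(!dims.empty());
  int64_t sum = 0;
  for (int64_t d : dims) sum += d;
  return sum / static_cast<int64_t>(dims.size());
}
```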
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
@@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
/*! \brief Create a grouped matrix multiplication configuration. */
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config();
/*! \brief Query an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written);
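The getter follows the usual two-call size-query convention: call once with `buf = NULL` to learn the required size via `size_written`, then call again with an adequately sized buffer. A self-contained sketch of that convention using a hypothetical stand-in attribute (names are illustrative, not TE internals):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <cstddef>

// Hypothetical stand-in following the nvte_get_grouped_matmul_config_attribute
// convention: with buf == NULL, only report how many bytes would be written.
static int64_t g_avg_m = 512;  // pretend this attribute was set earlier

void get_avg_m_attribute(void *buf, size_t size_in_bytes, size_t *size_written) {
  const size_t needed = sizeof(int64_t);
  if (size_written != nullptr) *size_written = needed;
  if (buf == nullptr) return;        // size query only
  assert(size_in_bytes >= needed);   // caller's buffer must be large enough
  std::memcpy(buf, &g_avg_m, needed);
}

int64_t query_avg_m() {
  size_t needed = 0;
  get_avg_m_attribute(nullptr, 0, &needed);      // first call: learn the size
  assert(needed == sizeof(int64_t));
  int64_t value = 0;
  get_avg_m_attribute(&value, needed, &needed);  // second call: read the value
  return value;
}
```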
/*! \brief Set an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
 * \param[in] buf Memory address to read option value from.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a grouped matrix multiplication configuration. */
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
@@ -228,6 +295,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
*
* \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
* Will error at runtime if compiled with an older cuBLAS version or run on
* a pre-Blackwell GPU.
*
* Performs batched GEMM on a collection of matrices with potentially different shapes.
* All tensors in the group must have compatible dimensions for matrix multiplication.
* Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous
* memory layout and shape metadata.
*
* \param[in] A Input grouped tensor A.
* \param[in] transa Whether to transpose A matrices.
* \param[in] B Input grouped tensor B.
* \param[in] transb Whether to transpose B matrices.
* \param[in] C Input grouped tensor C (can be NULL for beta=0).
* \param[out] D Output grouped tensor D.
* \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements).
* \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements).
* \param[in] workspace_setup Workspace tensor for pointer array setup.
* \param[in] workspace_cublas Workspace tensor for cuBLAS operations.
* \param[in] config Additional configuration (can be NULL for defaults).
* \param[in] stream CUDA stream for the operation.
*
* Requirements:
* - cuBLAS 13.1+ (CUDA 13.1+)
* - Blackwell (SM100) or newer GPU architecture
* - A, B, C (if provided), D must have the same num_tensors
* - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
* - Shape compatibility: if transa=false, transb=false:
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
*/
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
@@ -331,6 +438,70 @@ class MatmulConfigWrapper {
NVTEMatmulConfig config_ = nullptr;
};
/*! \struct GroupedMatmulConfigWrapper
* \brief C++ wrapper for NVTEGroupedMatmulConfig.
*/
class GroupedMatmulConfigWrapper {
public:
GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {}
GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~GroupedMatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEGroupedMatmulConfig.
*
* \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.
*/
operator NVTEGroupedMatmulConfig() const noexcept { return config_; }
/*! \brief Set average M dimension hint for algorithm selection. */
void set_avg_m(int64_t avg_m) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m,
sizeof(int64_t));
}
/*! \brief Set average N dimension hint for algorithm selection. */
void set_avg_n(int64_t avg_n) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n,
sizeof(int64_t));
}
/*! \brief Set average K dimension hint for algorithm selection. */
void set_avg_k(int64_t avg_k) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k,
sizeof(int64_t));
}
/*! \brief Set number of streaming multiprocessors to use. */
void set_sm_count(int sm_count) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count,
sizeof(int));
}
private:
/*! \brief Wrapped NVTEGroupedMatmulConfig. */
NVTEGroupedMatmulConfig config_ = nullptr;
};
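The wrapper's ownership scheme (copy deleted, move transfers the handle and nulls the source so double-destroy is impossible) can be exercised with mock create/destroy calls standing in for `nvte_create_grouped_matmul_config` / `nvte_destroy_grouped_matmul_config`; everything here is a hypothetical sketch of the pattern, not the TE code:

```cpp
#include <cassert>
#include <utility>

// Mock handle lifecycle: count live handles to observe ownership transfer.
static int g_live_configs = 0;
void *mock_create_config() { ++g_live_configs; return &g_live_configs; }
void mock_destroy_config(void *h) { if (h != nullptr) --g_live_configs; }

// Same shape as GroupedMatmulConfigWrapper: move-only RAII handle.
class ConfigWrapper {
 public:
  ConfigWrapper() : config_{mock_create_config()} {}
  ConfigWrapper(const ConfigWrapper &) = delete;
  ConfigWrapper &operator=(const ConfigWrapper &) = delete;
  ConfigWrapper(ConfigWrapper &&other) noexcept : config_{other.config_} {
    other.config_ = nullptr;  // source no longer owns the handle
  }
  ConfigWrapper &operator=(ConfigWrapper &&other) noexcept {
    if (this != &other) {
      mock_destroy_config(config_);  // release any handle we already hold
      config_ = other.config_;
      other.config_ = nullptr;
    }
    return *this;
  }
  ~ConfigWrapper() { mock_destroy_config(config_); }

 private:
  void *config_ = nullptr;
};

int live_after_moves() {
  ConfigWrapper a;                 // one live handle
  ConfigWrapper b = std::move(a);  // still one: ownership moved, not copied
  return g_live_configs;
}
```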
} // namespace transformer_engine
#endif // __cplusplus
@@ -6,6 +6,8 @@
#include "../util/cuda_runtime.h"
#include <cublasLt.h>
#include <filesystem>
#include <mutex>
@@ -210,6 +212,12 @@ int cudart_version() {
return version;
}
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
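The caching here relies on a function-local static being initialized exactly once (thread-safe since C++11), so repeated calls reuse the value instead of re-querying cuBLAS. A runnable sketch of that pattern with a stand-in for `cublasLtGetVersion()` (the counter and version value are illustrative):

```cpp
#include <cassert>
#include <cstddef>

// Stand-in for cublasLtGetVersion(): count how often the "expensive" query runs.
static int g_query_calls = 0;
size_t expensive_version_query() { ++g_query_calls; return 130100; }

// Same shape as cublas_version(): the static local is initialized on the
// first call only; every later call returns the cached value.
size_t cached_version() {
  static size_t version = expensive_version_query();
  return version;
}
```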
} // namespace cuda
} // namespace transformer_engine
@@ -73,6 +73,12 @@ const std::string &include_directory(bool required = false);
*/
int cudart_version();
/* \brief cuBLAS version number at run-time
*
* Versions may differ between compile-time and run-time.
*/
size_t cublas_version();
} // namespace cuda
} // namespace transformer_engine