Unverified commit 8aee1bb7 authored by alan yang, committed by GitHub

[Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045)



* feat: add cutlass group gemm support
Signed-off-by: Min Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor: refactor multi tensor gemm interface
Signed-off-by: Min Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor: refactor nvte_multi_stream_cublas_gemm func and add license info
Signed-off-by: Min Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: add unit test for cutlass group gemm
Signed-off-by: Min Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: add type protection for cutlass support
Signed-off-by: Min Yang <min.yang@shopee.com>

* add tests and fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: fix unit test errors
Signed-off-by: Min Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: refactor host workspace malloc
Signed-off-by: Min Yang <min.yang@shopee.com>

* update cutlass
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update cutlass
Signed-off-by: Xin Yao <xiny@nvidia.com>

* further relax threshold and add an env var to warn on fallback
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Min Yang <min.yang@shopee.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: alan yang <89962857+cassiewilliam@users.noreply.github.com>
Co-authored-by: Min Yang <min.yang@shopee.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent eb69fad7
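For orientation before the diff: the new CUTLASS path is opt-in via environment variables introduced by this PR. A minimal usage sketch follows; the `te.GroupedLinear` constructor and `m_splits` argument are assumptions based on the existing Transformer Engine PyTorch API, not part of this change:

```python
# Illustrative sketch only: te.GroupedLinear usage is assumed from the existing
# TE PyTorch API; the two env vars below come from this PR.
import os
import torch
import transformer_engine.pytorch as te

os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"            # opt in (Hopper/SM90 only)
os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1"  # warn when falling back to cuBLAS

# Three grouped GEMMs with K = 1024, satisfying the uniform K % 128 == 0 requirement.
layer = te.GroupedLinear(3, 1024, 1024, bias=False, params_dtype=torch.bfloat16).cuda()
x = torch.randn(3 * 512, 1024, dtype=torch.bfloat16, device="cuda")
out = layer(x, m_splits=[512, 512, 512])  # tokens routed to each of the 3 groups
```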
@@ -4,3 +4,6 @@
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198
@@ -125,6 +125,11 @@ if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
use_cutlass_grouped_gemm = [False]
# Only enable CUTLASS grouped GEMM on Hopper
if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm.append(True)
def is_fused_attn_available(
config: ModelConfig,
@@ -1805,6 +1810,7 @@ def test_grouped_linear_accuracy(
bias,
delay_wgrad_compute,
parallel_mode=None,
use_cutlass=False,
):
fp8 = recipe is not None
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
@@ -1876,11 +1882,49 @@ def test_grouped_linear_accuracy(
delay_wgrad_compute,
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
for o, o_ref in zip(outputs, outputs_ref):
if use_cutlass:
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
else:
# The cuBLAS implementation should match bit-wise
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy_cutlass(
dtype,
num_gemms,
bs,
model,
fuse_wgrad_accumulation,
delay_wgrad_compute,
):
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
test_grouped_linear_accuracy(
dtype,
num_gemms,
bs,
model,
None,
False,
fuse_wgrad_accumulation,
False,
delay_wgrad_compute,
None,
use_cutlass=True,
)
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
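A side note on the env-var handling above: a bare `os.environ[...] = "1"` followed by `pop` leaks the variable if the wrapped call raises. A hedged alternative (illustrative, not part of this PR) is pytest's `monkeypatch` fixture, which restores the environment at teardown even on failure:

```python
# Sketch: same test body, with env-var lifetime managed by pytest
# (parametrize decorators as above, omitted here for brevity).
def test_grouped_linear_accuracy_cutlass_monkeypatched(
    monkeypatch, dtype, num_gemms, bs, model, fuse_wgrad_accumulation, delay_wgrad_compute
):
    monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")
    test_grouped_linear_accuracy(
        dtype, num_gemms, bs, model, None, False,
        fuse_wgrad_accumulation, False, delay_wgrad_compute, None,
        use_cutlass=True,
    )
    # No manual cleanup needed: monkeypatch undoes setenv automatically.
```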
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@@ -2542,10 +2586,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
(16, 10027, 128, 512),
],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm)
def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
torch.manual_seed(0)
z, m, k, n = shape
@@ -2580,6 +2625,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
grad = True
single_output = False
if use_cutlass:
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
for i in range(z):
general_gemm(
A[i],
@@ -2607,9 +2655,15 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
single_output=single_output,
)
# Should be a bit-wise match
for o, o_ref in zip(out, out_ref):
if not use_cutlass:
# The cuBLAS implementation should match bit-wise
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
else:
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
if use_cutlass:
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
@pytest.mark.parametrize("N", [32])
......
@@ -45,6 +45,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
set(CUTLASS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include")
set(CUTLASS_TOOLS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include")
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
@@ -81,6 +86,7 @@ list(APPEND transformer_engine_SOURCES
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
@@ -121,18 +127,30 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"gemm/cutlass_grouped_gemm.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
else()
message(FATAL_ERROR "gemm/cutlass_grouped_gemm.cu requires CUDA 12.0+ to compile for sm_90a")
endif()
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
......
@@ -19,6 +19,7 @@
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#include "cutlass_grouped_gemm.cuh"
namespace {
@@ -650,9 +651,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
CUBLAS_VERSION);
#endif
NVTE_CHECK(
cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
transformer_engine::cuda::cudart_version() >= 12020 &&
transformer_engine::cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
@@ -675,13 +677,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
n_split, gemm_producer, inputCounter, stream);
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
int num_streams = nvte_get_num_compute_streams();
@@ -711,6 +711,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
}
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
// Deprecation warning
NVTE_WARN(
"nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
"Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
"applicable).");
multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
accumulate, use_split_accumulator, math_sm_count, stream);
}
namespace transformer_engine {
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
@@ -718,3 +737,85 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
} // namespace transformer_engine
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_gemm);
const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
const bool warn_fallback =
transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);
auto cublas_path = [&]() {
multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
workspace, accumulate, use_split_accumulator, math_sm_count, stream);
};
// CUTLASS grouped GEMM is currently supported only on the Hopper architecture
if (!(is_hopper && use_cutlass)) {
cublas_path();
return;
}
auto is_empty_arr = [&](const NVTETensor *p) -> bool {
if (p == nullptr) return true;
for (int i = 0; i < num_gemms; ++i) {
if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false;
}
return true;
};
auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
int64_t ref_k = -1;
for (int i = 0; i < num_gemms; ++i) {
const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]);
const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1];
if ((k & 127) != 0) return false;
if (ref_k < 0)
ref_k = k;
else if (k != ref_k)
return false;
}
return true;
};
auto is_supported_dtype = [&]() -> bool {
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
auto A_type = get_cuda_dtype(inputA->data.dtype);
auto B_type = get_cuda_dtype(inputB->data.dtype);
auto D_type = get_cuda_dtype(OutputD->data.dtype);
return (A_type == B_type) && (A_type == D_type) &&
((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
};
// CUTLASS Grouped GEMM fast path (SM90/TMA)
// Conditions:
// - No fused epilogue: both bias and pre_gelu_out are empty.
// - Supported dtypes only: FP16/BF16 (FP32 accumulate).
// - Uniform K across groups and K % 128 == 0.
// - use_split_accumulator is ignored for FP16/BF16.
// - grad is irrelevant when bias/pre_gelu_out are empty.
//
// Otherwise, fall back to cuBLAS.
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
all_groups_uniform_k128(B, transb)) {
cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
current_device, math_sm_count, stream);
} else {
if (warn_fallback) {
NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
}
cublas_path();
}
}
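To summarize the gating above in one place, here is a hedged Python mirror of the dispatch predicate (illustrative only; the authoritative checks are the C++ lambdas in `nvte_multi_tensor_gemm`):

```python
def would_use_cutlass(sm_arch, env_opt_in, has_bias, has_pre_gelu, dtypes_abd, k_per_group):
    """True -> CUTLASS grouped GEMM fast path, False -> cuBLAS fallback."""
    if sm_arch != 90 or not env_opt_in:      # Hopper only, opt-in via NVTE_USE_CUTLASS_GROUPED_GEMM
        return False
    if has_bias or has_pre_gelu:             # no fused epilogue on the CUTLASS path
        return False
    a, b, d = dtypes_abd
    if not (a == b == d and a in ("fp16", "bf16")):  # FP16/BF16 operands, FP32 accumulation
        return False
    ks = set(k_per_group)
    return len(ks) == 1 and ks.pop() % 128 == 0      # uniform K across groups, K % 128 == 0
```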
/***************************************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
**************************************************************************************************/
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass_grouped_gemm.cuh"
namespace transformer_engine {
namespace grouped_gemm {
// Explicit template instantiation to match the template declarations in the .cuh
template void CutlassGroupedGemm<false, false, cutlass::half_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
template void CutlassGroupedGemm<true, false, cutlass::half_t>(const NVTETensor*, const NVTETensor*,
NVTETensor*, NVTETensor*, float,
float, int, cudaStream_t, int, int);
template void CutlassGroupedGemm<false, true, cutlass::half_t>(const NVTETensor*, const NVTETensor*,
NVTETensor*, NVTETensor*, float,
float, int, cudaStream_t, int, int);
template void CutlassGroupedGemm<false, false, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
template void CutlassGroupedGemm<true, false, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
template void CutlassGroupedGemm<false, true, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
} // namespace grouped_gemm
} // namespace transformer_engine
void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
bool accumulate, int device, int math_sm_count, cudaStream_t stream) {
using namespace transformer_engine;
auto* inputA = convertNVTETensorCheck(A[0]);
auto* inputB = convertNVTETensorCheck(B[0]);
float one = 1.0;
float zero = 0.0;
float alpha = one;
float beta = (accumulate) ? one : zero;
auto dispatch = [&](auto tag) {
using T = decltype(tag);
if (!transa && !transb) {
grouped_gemm::CutlassGroupedGemm<false, false, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
} else if (!transb && transa) {
grouped_gemm::CutlassGroupedGemm<false, true, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
} else if (transb && !transa) {
grouped_gemm::CutlassGroupedGemm<true, false, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
} else {
NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm.");
}
};
if (inputA->data.dtype == DType::kBFloat16) {
dispatch(cutlass::bfloat16_t{});
} else if (inputA->data.dtype == DType::kFloat16) {
dispatch(cutlass::half_t{});
} else {
NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported.");
}
}
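The operand swap in the dispatch above is the standard row-major/column-major identity: the NVTE `(transa, transb)` flags follow the column-major cuBLAS convention, while the CUTLASS kernel consumes row-major tensors, so the wrapper evaluates D^T = op(B)^T · op(A)^T by passing `(B, A)` with the trans flags exchanged. A hedged summary of the mapping (illustrative):

```python
# (transa, transb) in the cuBLAS convention -> CutlassGroupedGemm<trans_a, trans_b>(B, A, ...)
layout_map = {
    ("N", "N"): "<false, false>",  # NN: plain row-major product
    ("T", "N"): "<false, true>",   # TN: A transposed -> second operand transposed
    ("N", "T"): "<true, false>",   # NT: B transposed -> first operand transposed
    # ("T", "T") is rejected with NVTE_ERROR above.
}
```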
/***************************************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
**************************************************************************************************/
//
// Copyright (c) 2025 Shopee Inc. All Rights Reserved.
//
/**
* @file: cutlass_grouped_gemm.cuh
* @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com
* @date: 2025-08-08 16:20:00
* @brief: CUTLASS grouped GEMM kernel.
**/
#pragma once
#include <transformer_engine/transformer_engine.h>
#include <cub/cub.cuh>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "common/util/system.h"
#include "cute/tensor.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
namespace transformer_engine {
namespace grouped_gemm {
template <bool trans_a>
using GroupedGemmInputALayout =
std::conditional_t<trans_a, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
template <bool trans_b>
using GroupedGemmInputBLayout =
std::conditional_t<trans_b, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
using ProblemShapeType = cute::Shape<int, int, int>;
using ProblemShape = cutlass::gemm::GroupProblemShape<ProblemShapeType>; // <M,N,K> per group
template <typename ScheduleConfig>
struct GemmGivenSchedule {
using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand
using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand
using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands
// A matrix configuration
using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<
ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<
ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using LayoutC = typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<
ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag =
cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
using ClusterShape =
typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
template <typename DataType_, bool trans_a, bool trans_b>
struct ScheduleConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
// TODO(Alan): Add tuning for different scenarios to select the optimal configuration,
// as the current configuration may not be the best.
// using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
// using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
// using TileShape = Shape<cute::_256, cute::_128, cute::_128>;
// using ClusterShape = Shape<cute::_1, cute::_2, cute::_1>;
using LayoutA = GroupedGemmInputALayout<trans_a>;
using LayoutB = GroupedGemmInputBLayout<trans_b>;
using LayoutC = cutlass::layout::RowMajor;
using DataType = DataType_;
};
template <typename DataType_, bool trans_a, bool trans_b>
using GemmGrouped = typename GemmGivenSchedule<ScheduleConfig<DataType_, trans_a, trans_b>>::Gemm;
template <typename GemmT, typename ElementA, typename ElementB, typename ElementC, typename StrideA,
typename StrideB, typename StrideC>
typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host,
void* problem_sizes, const ElementA** ptr_A,
StrideA* stride_A, const ElementB** ptr_B,
StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C,
float alpha, float beta, int device, int math_sm_count) {
// Build kernel hardware info for the target device; a non-zero math_sm_count
// limits the number of SMs made available to the scheduler.
cutlass::KernelHardwareInfo kernel_hw_info =
cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename GemmT::GemmKernel>(
device, math_sm_count);
typename GemmT::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha = alpha;
fusion_args.beta = beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
arguments =
typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, reinterpret_cast<ProblemShapeType*>(problem_sizes),
reinterpret_cast<ProblemShapeType const*>(problem_sizes_host)},
{ptr_A, stride_A, ptr_B, stride_B},
{
fusion_args,
(beta > 0.0) ? (const ElementC**)ptr_C : nullptr, // NOLINT(*)
stride_C,
ptr_C,
stride_C,
},
kernel_hw_info};
return arguments;
}
template <typename T>
inline __device__ __host__ T ROUND_UP(T m, T n) {
return (m + n - 1) / n * n;
}
template <typename T>
void debug_type() {
std::cout << typeid(T).name() << std::endl;  // debug helper: prints the mangled name of T
}
int64_t inline getGemmCoordSize(int64_t num_gemms) {
return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL));
}
int64_t inline getPtrSize(int64_t num_gemms) {
return (int64_t)(ROUND_UP(num_gemms * sizeof(half*), 128UL));
}
int64_t inline getLddSize(int64_t num_gemms) {
return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL));
}
// CPU workspace size is 4 MiB
static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024;
static char* getHostWorkspace() {
static std::once_flag flag;
static std::shared_ptr<char> workspace;
std::call_once(flag, [&]() {
workspace =
std::shared_ptr<char>(reinterpret_cast<char*>(std::malloc(kCPUWorkSpaceSize)), [](char* p) {
if (p) std::free(p);
});
if (!workspace) {
throw std::bad_alloc();
}
});
return workspace.get();
}
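The parameter block staged through this host workspace holds, in order: the per-group problem sizes, the A/B/C pointer arrays, and the lda/ldb/ldc arrays, each section rounded up to 128 bytes by the helpers above. A quick arithmetic sketch (illustrative) for `num_gemms = 6` on a 64-bit platform:

```python
def round_up(m, n):                      # mirrors ROUND_UP above
    return (m + n - 1) // n * n

num_gemms = 6
coord = round_up(num_gemms * 12, 128)    # ProblemShapeType = Shape<int, int, int>: 12 B each
ptrs  = round_up(num_gemms * 8, 128)     # one 8-byte device pointer per group
ldd   = round_up(num_gemms * 8, 128)     # one int64 leading dimension per group
total = coord + 3 * ptrs + 3 * ldd       # problem sizes + {A,B,C} pointers + {lda,ldb,ldc}
assert total == 896                      # comfortably below the 4 MiB kCPUWorkSpaceSize cap
```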
template <bool trans_a, bool trans_b, typename Element>
void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
NVTETensor* workspace, float alpha, float beta, int num_gemms,
cudaStream_t stream, int device, int math_sm_count) {
using Gemm = GemmGrouped<Element, trans_a, trans_b>;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
typename Gemm::Arguments arguments;
size_t kernel_workspace_size = Gemm::get_workspace_size(arguments);
auto gemm_coord_size = getGemmCoordSize(num_gemms);
auto ptr_size = getPtrSize(num_gemms);
auto ldd_size = getLddSize(num_gemms);
auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size;
NVTE_CHECK(
param_workspace_size < kCPUWorkSpaceSize,
"Insufficient kCPUWorkSpaceSize size: required=", static_cast<int64_t>(param_workspace_size),
", available=", static_cast<int64_t>(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM.");
auto total_workspace_size = param_workspace_size + kernel_workspace_size;
transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]);
NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=",
static_cast<int64_t>(total_workspace_size),
", available=", static_cast<int64_t>(wspace->numel()), " for CUTLASS grouped GEMM.");
char* workspace_ptr = reinterpret_cast<char*>(wspace->data.dptr);
char* kernel_workspace_ptr = nullptr;
char* host_workspace = getHostWorkspace();
ProblemShapeType* problem_sizes_host = reinterpret_cast<ProblemShapeType*>(host_workspace);
ElementA** ptr_A_host = reinterpret_cast<ElementA**>(host_workspace + gemm_coord_size);
ElementB** ptr_B_host = reinterpret_cast<ElementB**>(host_workspace + gemm_coord_size + ptr_size);
ElementC** ptr_C_host =
reinterpret_cast<ElementC**>(host_workspace + gemm_coord_size + 2 * ptr_size);
int64_t* lda_host =
reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size);
int64_t* ldb_host =
reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size);
int64_t* ldc_host =
reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * ldd_size);
for (int i = 0; i < num_gemms; ++i) {
const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]);
const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]);
transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]);
const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0];
const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1];
const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1];
auto problem = ProblemShapeType(m, n, k);
problem_sizes_host[i] = problem;
ptr_A_host[i] = reinterpret_cast<ElementA*>(inputA->data.dptr);
ptr_B_host[i] = reinterpret_cast<ElementB*>(inputB->data.dptr);
ptr_C_host[i] = reinterpret_cast<ElementC*>(outputD->data.dptr);
lda_host[i] = LayoutA::packed({m, k}).stride(0);
ldb_host[i] = LayoutB::packed({k, n}).stride(0);
ldc_host[i] = LayoutC::packed({m, n}).stride(0);
}
cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice,
stream);
char* param_workspace_ptr = workspace_ptr;
ProblemShapeType* problem_sizes_device = reinterpret_cast<ProblemShapeType*>(param_workspace_ptr);
const ElementA** ptr_A = reinterpret_cast<const ElementA**>(
reinterpret_cast<char*>(param_workspace_ptr) + gemm_coord_size);
const ElementB** ptr_B = reinterpret_cast<const ElementB**>(
reinterpret_cast<char*>(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size);
ElementC** ptr_C = reinterpret_cast<ElementC**>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 2 * ptr_size);
StrideA* lda = reinterpret_cast<StrideA*>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 3 * ptr_size + 0 * ldd_size);
StrideB* ldb = reinterpret_cast<StrideB*>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 3 * ptr_size + 1 * ldd_size);
StrideC* ldc = reinterpret_cast<StrideC*>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 3 * ptr_size + 2 * ldd_size);
kernel_workspace_ptr = workspace_ptr + param_workspace_size;
arguments = MakeArguments<Gemm, ElementA, ElementB, ElementC, StrideA, StrideB, StrideC>(
num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc,
alpha, beta, device, math_sm_count);
Gemm gemm;
// Check that the kernel can implement the given arguments.
if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) {
NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM");
}
// Initialize the kernel.
if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) {
NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
}
// Execute the kernel in the current stream.
if (gemm.run(stream) != cutlass::Status::kSuccess) {
NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
}
}
} // namespace grouped_gemm
} // namespace transformer_engine
void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
bool accumulate, int device, int math_sm_count, cudaStream_t stream);
@@ -133,11 +133,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor* workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
@@ -526,10 +526,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i));
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans,
lhs_is_trans, grad, workspace_list.data(), accumulate,
use_split_accumulator, num_math_sm, stream);
nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans,
grad, workspace_list.data(), accumulate, use_split_accumulator,
num_math_sm, stream);
return ffi_with_cuda_error_check();
}
......
@@ -477,11 +477,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
// Dispatch to the grouped GEMM backend (multi-stream cuBLAS, or CUTLASS when enabled).
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(),
transa, transb, grad, te_workspace_vector.data(), accumulate,
use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream());
});
return bias;
}
......