Unverified commit 8dba2963, authored by Vladimir Cherepanov and committed by GitHub

Add cuBLASMp-backed GEMM-like API to TE common (#1824)



* Pick up cuBLASMp during build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscription is needed on some clusters
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent 1398fa5f
@@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp
 cmake -GNinja -Bbuild .
 cmake --build build
 export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
-ctest --test-dir build -j$NUM_PARALLEL_JOBS
+ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)'
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
cd $TE_PATH/tests/cpp
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm
@@ -4,6 +4,7 @@
 """Installation script."""
+from importlib import metadata
 import os
 import time
 from pathlib import Path
@@ -66,6 +67,18 @@ def setup_common_extension() -> CMakeExtension:
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
+    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
+        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
+        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
+            "nvidia-cublasmp-cu12"
+        ).locate_file("nvidia/cublasmp/cu12")
+        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
+        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
+            "nvidia-nvshmem-cu12"
+        ).locate_file("nvidia/nvshmem")
+        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
+        print("CMAKE_FLAGS:", cmake_flags[-2:])
     # Add custom CMake arguments from environment variable
     nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
     if nvte_cmake_extra_args:
......
@@ -37,10 +37,12 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
 message(STATUS "Found transformer_engine library: ${TE_LIB}")
 include_directories(../../transformer_engine/common/include)
 include_directories(../../transformer_engine/common)
+include_directories(../../transformer_engine)
 include_directories(${CMAKE_SOURCE_DIR})
 find_package(CUDAToolkit REQUIRED)
 include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
+add_subdirectory(comm_gemm)
 add_subdirectory(operator)
 add_subdirectory(util)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
add_executable(test_comm_gemm
test_comm_gemm.cu
../test_common.cu)
find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include)
target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX)
include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <gtest/gtest.h>
#include <mpi.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "../test_common.h"
#include "common.h"
using transformer_engine::DType;
using transformer_engine::TypeInfo;
#define CHECK_MPI(expr) \
do { \
int err = (expr); \
if (err != MPI_SUCCESS) { \
char err_str[MPI_MAX_ERROR_STRING + 1]{}; \
int _len{}; \
MPI_Error_string(err, err_str, &_len); \
EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \
} \
} while (false)
#define CHECK_NCCL(expr) \
do { \
ncclResult_t err = (expr); \
if (err != ncclSuccess) { \
EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \
} \
} while (false)
#define CHECK_CU(expr) \
do { \
CUresult err = (expr); \
if (err != CUDA_SUCCESS) { \
const char* str{}; \
CUresult e_str = cuGetErrorString(err, &str); \
if (e_str != CUDA_SUCCESS) str = "(unknown)"; \
EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \
} \
} while (false)
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
CHECK_MPI(MPI_Init(&argc, &argv));
auto ret = RUN_ALL_TESTS();
CHECK_MPI(MPI_Finalize());
return ret;
}
bool IsMulticastSupported(int device_id) {
int supported = 0;
CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id));
return supported;
}
template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld) {
std::vector<T> ret(msize * nsize);
size_t dst = 0;
for (size_t j = nstart; j < nstart + nsize; ++j) {
for (size_t i = mstart; i < mstart + msize; ++i) {
ret[dst++] = data[j * ld + i];
}
}
return ret;
}
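The CopyMatrix helper above extracts a contiguous submatrix from a column-major buffer, where element (i, j) lives at index `j * ld + i`. A minimal Python mirror of the same index arithmetic (the helper name `copy_matrix` is illustrative, not from the source):

```python
def copy_matrix(data, mstart, nstart, msize, nsize, ld):
    """Extract a column-major msize x nsize block starting at (mstart, nstart).

    `data` stores a column-major matrix with leading dimension `ld`,
    so element (i, j) lives at data[j * ld + i], as in the C++ test helper.
    """
    out = []
    for j in range(nstart, nstart + nsize):
        for i in range(mstart, mstart + msize):
            out.append(data[j * ld + i])
    return out

# A 3x3 column-major matrix: columns are [0,1,2], [3,4,5], [6,7,8].
full = list(range(9))
# The 2x2 block starting at row 1, col 1 is {{4,5},{7,8}}, column-major.
block = copy_matrix(full, 1, 1, 2, 2, 3)
```

The tests use this both to build per-rank local inputs and to slice the golden output for comparison.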
template <typename T>
test::Tensor Make(size_t m, size_t n, float scale) {
test::Tensor ret("", std::vector{n, m}, TypeInfo<T>::dtype);
ret.set_scale(scale);
ret.set_scale_inv(1.0 / scale);
return ret;
}
template <typename T>
test::Tensor MakeFromData(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld, float scale) {
test::Tensor ret("", std::vector{nsize, msize}, TypeInfo<T>::dtype);
ret.set_scale(scale);
ret.set_scale_inv(1.0 / scale);
auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld);
NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0],
cudaMemcpyDefault));
return ret;
}
template <typename T>
float GetScale(float amax) {
if constexpr (sizeof(T) > 1) return 1.0;
return static_cast<float>(static_cast<T>(std::numeric_limits<float>::max())) / amax;
}
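GetScale above returns 1 for types wider than one byte and otherwise maps the expected absolute maximum onto the FP8 format's largest finite value (casting `float` max to an FP8 type saturates to that value). A hedged Python sketch of the same idea, using the standard finite maxima 448 (E4M3) and 57344 (E5M2); the helper name and `itemsize` parameter are illustrative:

```python
# Largest finite magnitudes of the two FP8 formats used in the tests.
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}

def get_scale(fmt, amax, itemsize=1):
    """Multiplier that maps [0, amax] onto the FP8 range.

    Wider types (itemsize > 1 byte) need no scaling, mirroring the
    `sizeof(T) > 1` early-out in the C++ helper.
    """
    if itemsize > 1:
        return 1.0
    return FP8_MAX[fmt] / amax

scale = get_scale("e4m3", 2.0)  # values up to 2.0 fill the E4M3 range
inv = 1.0 / scale               # stored alongside as scale_inv for dequantization
```

Note how the test picks `amax = MAX_IN * MAX_IN * k` for D: each output element is a sum of k products of inputs bounded by MAX_IN.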
struct Params {
DType a_type;
DType b_type;
DType d_type;
bool transa;
bool transb;
size_t m;
size_t n;
size_t k;
float tol;
};
class CommGemmFixure : public ::testing::TestWithParam<Params> {
protected:
CommGemmFixure() {
CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_));
CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_));
NVTE_CHECK_CUDA(cudaSetDevice(rank_));
ncclUniqueId id{};
if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id));
CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_));
ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_);
}
~CommGemmFixure() {
nvte_comm_gemm_ctx_destroy(ctx_);
ncclCommDestroy(comm_);
}
struct PatternDims {
int64_t a_rows_start;
int64_t a_rows_num;
int64_t a_cols_start;
int64_t a_cols_num;
int64_t b_rows_start;
int64_t b_rows_num;
int64_t b_cols_start;
int64_t b_cols_num;
int64_t d_rows_start;
int64_t d_rows_num;
int64_t d_cols_start;
int64_t d_cols_num;
};
virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0;
virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) = 0;
template <typename AType, typename BType, typename DType, typename BiasType>
void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
cudaStream_t stream{};
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
constexpr float MAX_IN = 1.0;
std::mt19937 rng(12);
std::uniform_real_distribution<float> dist(0.0, MAX_IN);
float a_scale = GetScale<AType>(MAX_IN);
float b_scale = GetScale<BType>(MAX_IN);
float d_scale = GetScale<DType>(MAX_IN * MAX_IN * k);
float bias_scale = GetScale<BiasType>(MAX_IN);
std::vector<AType> adata(m * k);
std::generate(adata.begin(), adata.end(),
[&rng, &dist, a_scale] { return static_cast<AType>(dist(rng) * a_scale); });
std::vector<BType> bdata(k * n);
std::generate(bdata.begin(), bdata.end(),
[&rng, &dist, b_scale] { return static_cast<BType>(dist(rng) * b_scale); });
std::vector<BiasType> biasdata(m * n);
std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] {
return static_cast<BiasType>(dist(rng) * bias_scale);
});
auto ga = transa ? MakeFromData<AType>(adata, 0, 0, k, m, k, a_scale)
: MakeFromData<AType>(adata, 0, 0, m, k, m, a_scale);
auto gb = transb ? MakeFromData<BType>(bdata, 0, 0, n, k, n, b_scale)
: MakeFromData<BType>(bdata, 0, 0, k, n, k, b_scale);
auto gbias = MakeFromData<BiasType>(biasdata, 0, 0, m, n, m, bias_scale);
auto gd = Make<DType>(m, n, d_scale);
auto gaux = Make<DType>(m, n, d_scale);
auto dims = DistributeTensors(m, n, k);
auto a = transa ? MakeFromData<AType>(adata, dims.a_rows_start, dims.a_cols_start,
dims.a_rows_num, dims.a_cols_num, k, a_scale)
: MakeFromData<AType>(adata, dims.a_cols_start, dims.a_rows_start,
dims.a_cols_num, dims.a_rows_num, m, a_scale);
auto b = transb ? MakeFromData<BType>(bdata, dims.b_cols_start, dims.b_rows_start,
dims.b_cols_num, dims.b_rows_num, n, b_scale)
: MakeFromData<BType>(bdata, dims.b_rows_start, dims.b_cols_start,
dims.b_rows_num, dims.b_cols_num, k, b_scale);
auto bias = MakeFromData<BiasType>(biasdata, dims.d_rows_start, dims.d_cols_start,
dims.d_rows_num, dims.d_cols_num, m, bias_scale);
auto d = Make<DType>(dims.d_rows_num, dims.d_cols_num, d_scale);
auto aux = Make<DType>(dims.d_rows_num, dims.d_cols_num, d_scale);
bool grad = false;
bool accumulate = false;
CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad,
accumulate, 0 /*comm_sm_count*/, stream);
auto workspace = Make<uint8_t>(1, 32 << 20, 1.0);
nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb,
grad, workspace.data(), accumulate, false /* use_split_accumulator */,
0 /* math_sm_count */, stream);
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
std::vector<DType> out(dims.d_rows_num * dims.d_cols_num);
NVTE_CHECK_CUDA(
cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault));
std::vector<DType> out_golden_global(m * n);
NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(),
out_golden_global.size() * sizeof out_golden_global[0],
cudaMemcpyDefault));
auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start,
dims.d_rows_num, dims.d_cols_num, m);
NVTE_CHECK(out.size() == out_golden.size());
for (size_t i = 0; i < out.size(); ++i) {
EXPECT_NEAR(static_cast<float>(out[i]), static_cast<float>(out_golden[i]), tol * k);
}
}
NVTECommGemmCtx* ctx_{};
int nranks_{};
int rank_{};
ncclComm_t comm_{};
};
struct AgGemm : public CommGemmFixure {
PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m);
auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n);
int64_t a_cols_start{};
int64_t b_cols_start{};
MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
return PatternDims{
.a_rows_start = 0,
.a_rows_num = k,
.a_cols_start = a_cols_start,
.a_cols_num = a_cols_num,
.b_rows_start = 0,
.b_rows_num = k,
.b_cols_start = b_cols_start,
.b_cols_num = b_cols_num,
.d_rows_start = a_cols_start,
.d_rows_num = a_cols_num,
.d_cols_start = 0,
.d_cols_num = n,
};
}
void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) override {
nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
};
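In the AgGemm fixture, each rank owns a block of columns whose size comes from `nvte_comm_gemm_numroc` and whose offset is the exclusive prefix sum computed by MPI_Exscan (rank 0 gets 0). A self-contained sketch of that bookkeeping, assuming a simple even-block split (the real numroc may distribute differently; `numroc` and `exscan` here are hypothetical stand-ins):

```python
def numroc(n, nranks, rank):
    """Items owned by `rank` when n items are block-distributed, handling
    uneven division by giving the first `n % nranks` ranks one extra item."""
    base, rem = divmod(n, nranks)
    return base + (1 if rank < rem else 0)

def exscan(counts):
    """Exclusive prefix sum, as MPI_Exscan produces (rank 0 gets 0)."""
    starts, total = [], 0
    for c in counts:
        starts.append(total)
        total += c
    return starts

nranks, m = 4, 10  # 10 columns over 4 ranks: uneven split
counts = [numroc(m, nranks, r) for r in range(nranks)]
starts = exscan(counts)
# Every column is owned exactly once and the blocks tile [0, m).
assert sum(counts) == m and starts[-1] + counts[-1] == m
```

The same count/offset pair then selects the local slice of A, B, and D in DistributeTensors.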
struct GemmRs : public CommGemmFixure {
PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
auto rows_num = nvte_comm_gemm_numroc(ctx_, k);
auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n);
int64_t rows_start{};
int64_t d_cols_start{};
MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
return PatternDims{
.a_rows_start = rows_start,
.a_rows_num = rows_num,
.a_cols_start = 0,
.a_cols_num = m,
.b_rows_start = rows_start,
.b_rows_num = rows_num,
.b_cols_start = 0,
.b_cols_num = n,
.d_rows_start = 0,
.d_rows_num = m,
.d_cols_start = d_cols_start,
.d_cols_num = d_cols_num,
};
}
void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) override {
nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
};
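GemmRs splits the contraction dimension k across ranks: each rank multiplies only its row block of A and B, and the reduce-scatter sums the partial products (while scattering columns of D). A pure-Python check, on small nested lists, that summed partial-k products reproduce the full GEMM; all names here are illustrative:

```python
def matmul(a, b):
    """Plain triple-loop matmul on row-major nested lists."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]

a = [[1, 2, 3, 4], [5, 6, 7, 8]]      # 2x4
b = [[1, 0], [0, 1], [1, 1], [2, 2]]  # 4x2
full = matmul(a, b)

# Split k=4 over 2 "ranks": each multiplies its own k-slice only.
parts = [matmul([row[:2] for row in a], b[:2]),
         matmul([row[2:] for row in a], b[2:])]
summed = [[parts[0][i][j] + parts[1][i][j] for j in range(2)] for i in range(2)]
assert summed == full  # reducing the partials reproduces the GEMM
```

GemmAr uses the same k-split but all-reduces, so every rank ends up with the full D instead of a column slice.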
struct GemmAr : public CommGemmFixure {
PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
auto rows_num = nvte_comm_gemm_numroc(ctx_, k);
int64_t rows_start{};
MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
return PatternDims{
.a_rows_start = rows_start,
.a_rows_num = rows_num,
.a_cols_start = 0,
.a_cols_num = m,
.b_rows_start = rows_start,
.b_rows_num = rows_num,
.b_cols_start = 0,
.b_cols_num = n,
.d_rows_start = 0,
.d_rows_num = m,
.d_cols_start = 0,
.d_cols_num = n,
};
}
void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) override {
nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
void SetUp() override {
if (!IsMulticastSupported(rank_))
GTEST_SKIP() << "Multicast is not supported on device " << rank_;
}
};
TEST_P(AgGemm, Gemm) {
auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
a_type, AType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
b_type, BType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}
TEST_P(GemmRs, Gemm) {
auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
a_type, AType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
b_type, BType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}
TEST_P(GemmAr, Gemm) {
auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
a_type, AType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
b_type, BType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}
std::string ParamSuffix(const testing::TestParamInfo<Params>& info) {
const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param;
std::ostringstream ss;
ss << static_cast<int>(a_type) << "_" << static_cast<int>(b_type) << "_"
<< static_cast<int>(d_type) << "_" << (transa ? "T" : "N") << (transb ? "T" : "N") << "_" << m
<< "x" << n << "x" << k;
return ss.str();
}
INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm,
testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, false, 256, 128, 64, 1e-3},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, true, 256, 128, 64, 1e-3},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
true, false, 256, 128, 64, 1e-3},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, false, 256, 128, 64, 1e-3},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, true, 256, 128, 64, 1e-3},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, true, false, 256, 128, 64, 1e-3},
Params{DType::kFloat8E4M3, DType::kFloat8E4M3,
DType::kFloat16, true, false, 256, 128, 64, 1e-3},
Params{DType::kFloat8E4M3, DType::kFloat8E5M2,
DType::kFloat16, true, false, 256, 128, 64, 1e-3},
Params{DType::kFloat8E5M2, DType::kFloat8E4M3,
DType::kFloat16, true, false, 256, 128, 64, 1e-3}),
&ParamSuffix);
INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs,
testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, false, 64, 128, 256, 5e-2},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, true, 64, 128, 256, 5e-2},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
true, false, 64, 128, 256, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, false, 64, 128, 256, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, true, 64, 128, 256, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, true, false, 64, 128, 256, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E4M3,
DType::kFloat16, true, false, 64, 128, 256, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E5M2,
DType::kFloat16, true, false, 64, 128, 256, 5e-2},
Params{DType::kFloat8E5M2, DType::kFloat8E4M3,
DType::kFloat16, true, false, 64, 128, 256, 5e-2}),
&ParamSuffix);
INSTANTIATE_TEST_SUITE_P(
GemmAr, GemmAr,
testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64,
64 * 4, 64 * 4, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64,
64 * 4, 64 * 4, 5e-2},
Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false,
128, 128 * 4, 128 * 4, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false,
128, 128 * 4, 128 * 4, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false,
128, 128 * 4, 128 * 4, 5e-2}),
&ParamSuffix);
@@ -110,6 +110,12 @@ list(APPEND transformer_engine_SOURCES
 comm_gemm_overlap/userbuffers/userbuffers-host.cpp
 comm_gemm_overlap/userbuffers/userbuffers.cu
 comm_gemm_overlap/comm_gemm_overlap.cpp)
+if (NVTE_WITH_CUBLASMP)
+list(APPEND transformer_engine_SOURCES
+comm_gemm/comm_gemm.cpp)
+endif()
 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
 "${CMAKE_CURRENT_SOURCE_DIR}/include")
@@ -123,6 +129,8 @@ target_link_libraries(transformer_engine PUBLIC
 CUDNN::cudnn_all)
 target_include_directories(transformer_engine PRIVATE
 ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+target_include_directories(transformer_engine SYSTEM PRIVATE
+${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
 target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
 # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
@@ -141,6 +149,25 @@ if (NVTE_ENABLE_NVSHMEM)
 target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
 endif()
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
find_library(CUBLASMP_LIB
NAMES cublasmp libcublasmp
PATHS ${CUBLASMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
find_library(NVSHMEM_HOST_LIB
NAMES nvshmem_host libnvshmem_host.so.3
PATHS ${NVSHMEM_DIR}
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()
# Hack to enable dynamic loading in cuDNN frontend # Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
......
This diff is collapsed.
@@ -26,6 +26,24 @@ __global__ void __launch_bounds__(1)
 }  // namespace
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDA_R_16F;
case DType::kFloat32:
return CUDA_R_32F;
case DType::kBFloat16:
return CUDA_R_16BF;
case DType::kFloat8E4M3:
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
 void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
   if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
     NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
......
@@ -270,6 +270,8 @@ struct QuantizationConfig {
 };
 };
+cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);
 template <typename T>
 constexpr T DIVUP(const T &x, const T &y) {
   return (((x) + ((y)-1)) / (y));
@@ -382,9 +384,19 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1
+#if CUDA_VERSION >= 12080
+                           ,
+                           fp8e8m0
+#endif
+                           >;
 #else
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2
+#if CUDA_VERSION >= 12080
+                           ,
+                           fp8e8m0
+#endif
+                           >;
 #endif
 template <typename U, DType current>
template <typename U, DType current> template <typename U, DType current>
......
@@ -22,24 +22,6 @@
 namespace {
-cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
-  using namespace transformer_engine;
-  switch (t) {
-    case DType::kFloat16:
-      return CUDA_R_16F;
-    case DType::kFloat32:
-      return CUDA_R_32F;
-    case DType::kBFloat16:
-      return CUDA_R_16BF;
-    case DType::kFloat8E4M3:
-      return CUDA_R_8F_E4M3;
-    case DType::kFloat8E5M2:
-      return CUDA_R_8F_E5M2;
-    default:
-      NVTE_ERROR("Invalid type");
-  }
-}
uint32_t _getAlignment(uintptr_t address) {
// alignment is in bytes
uint32_t alignment = 256;
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file comm_gemm.h
* \brief Functions for distributed (multi-GPU) matrix multiplication.
*
 * This API is a TE-native binding to the cuBLASMp library.
 * See https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for the specific
 * distribution patterns that allow communication-computation overlap.
*
 * All GEMM functions here share the same computational semantics, expressed
 * on global matrices, similar to the nvte_cublas_gemm call:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
 * The functions differ only in their matrix distribution patterns.
*/
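The three epilogue modes listed above can be illustrated with a plain single-threaded reference on tiny row-major matrices. This is a sketch for intuition only, not the TE implementation; the tanh-based GELU approximation here is an assumption:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Reference for the three modes: D = A*B, D = A*B + bias, D = GELU(A*B + bias).
// A is m x k, B is k x n, bias has n entries broadcast along rows.
std::vector<float> gemm_epilogue(const std::vector<float>& A, const std::vector<float>& B,
                                 const std::vector<float>& bias, bool apply_gelu,
                                 int m, int n, int k) {
  std::vector<float> D(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      if (!bias.empty()) acc += bias[j];  // bias broadcast along rows
      if (apply_gelu)  // tanh approximation of GELU (an assumption for illustration)
        acc = 0.5f * acc * (1.f + std::tanh(0.7978845608f * (acc + 0.044715f * acc * acc * acc)));
      D[i * n + j] = acc;
    }
  return D;
}
```
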
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#include <nccl.h>
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif
typedef struct NVTECommGemmCtx NVTECommGemmCtx;
enum NVTECommGemmAlgoType {
kNVTECommGemmAlgoDefault = 0,
kNVTECommGemmAlgoSplitP2P = 1,
kNVTECommGemmAlgoSplitMulticast = 2,
kNVTECommGemmAlgoAtomicP2P = 3,
kNVTECommGemmAlgoAtomicMulticast = 4
};
/*! \brief Create a comm-gemm context.
*
* \param[in] comm NCCL communicator.
* \param[in] nranks Number of ranks.
* \param[in] rank Local rank.
*/
NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank);
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
/*! \brief Perform AllGather communication followed by GEMM
*
* Gathers distributed data from all ranks, then computes matrix multiplication.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo);
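The AllGather+GEMM semantics can be sanity-checked with a single-process sketch: each "rank" owns a row block of the global A, the gather concatenates the blocks in rank order, and the subsequent GEMM reproduces the global product. This is an illustration of the math only, not the cuBLASMp algorithm:

```cpp
#include <cassert>
#include <vector>

// a_blocks[r] is rank r's row block of A (each of width k, row-major);
// B is k x n. Concatenation simulates the AllGather step.
std::vector<float> allgather_gemm_ref(const std::vector<std::vector<float>>& a_blocks,
                                      const std::vector<float>& B, int n, int k) {
  std::vector<float> A;  // "AllGather": concatenate row blocks in rank order
  for (const auto& blk : a_blocks) A.insert(A.end(), blk.begin(), blk.end());
  const int m = static_cast<int>(A.size()) / k;
  std::vector<float> D(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p) D[i * n + j] += A[i * k + p] * B[p * n + j];
  return D;
}
```
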
/*! \brief Perform GEMM followed by ReduceScatter communication
*
* Computes matrix multiplication, then distributes results across ranks with reduction.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
const NVTETensor a, const NVTETensor b, const NVTETensor d,
const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Perform GEMM followed by AllReduce communication
*
* Computes matrix multiplication, then reduces results across all ranks.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo);
/*! \brief Get local number of rows or columns.
*
* Utility function to get local dimension.
 * Block size, nranks, and local rank are derived from the context ctx.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] global_size Global dimension.
*/
int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size);
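The "numroc" name comes from the classic ScaLAPACK NUMROC routine for block-cyclic distributions. A plain sketch of that formula is below; the actual block size and rank mapping that nvte_comm_gemm_numroc derives from the context are internal to cuBLASMp, so treat this only as an illustration of the counting:

```cpp
#include <cassert>
#include <cstdint>

// Number of rows/columns of a block-cyclically distributed global dimension n
// (block size nb) owned by one rank, assuming distribution starts at rank 0.
int64_t numroc_ref(int64_t n, int64_t nb, int rank, int nranks) {
  const int64_t nblocks = n / nb;           // full blocks in the dimension
  int64_t local = (nblocks / nranks) * nb;  // evenly cycled full blocks
  const int64_t extra = nblocks % nranks;
  if (rank < extra) {
    local += nb;       // ranks holding one extra full block
  } else if (rank == extra) {
    local += n % nb;   // rank holding the trailing partial block
  }
  return local;
}
```
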
#ifdef __cplusplus
} // extern "C"
#endif
#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
@@ -12,8 +12,13 @@
#include <cudnn.h>
#include <nvrtc.h>
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
#include <iostream>
#include <stdexcept>
#include <string>
#include "../util/string.h"
@@ -87,4 +92,16 @@
} \
} while (false)
#ifdef NVTE_WITH_CUBLASMP
#define NVTE_CHECK_CUBLASMP(expr) \
do { \
const cublasMpStatus_t status = (expr); \
if (status != CUBLASMP_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \
} \
} while (false)
#endif // NVTE_WITH_CUBLASMP
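The NVTE_CHECK_CUBLASMP macro above uses the standard do { ... } while (false) wrapper so it behaves as a single statement (e.g. inside an unbraced if/else). A self-contained sketch of the same pattern, where MockStatus and the thrown exception stand in for the real cublasMpStatus_t and NVTE_ERROR (both assumptions for illustration):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

enum MockStatus { kSuccess = 0, kAllocFailed = 3 };  // stand-in for cublasMpStatus_t

// Single-statement status check: evaluates expr once, throws on failure.
#define CHECK_MOCK(expr)                                               \
  do {                                                                 \
    const MockStatus status_ = (expr);                                 \
    if (status_ != kSuccess) {                                         \
      throw std::runtime_error("Error: " + std::to_string(status_));   \
    }                                                                  \
  } while (false)

// Helper so the macro's behavior can be observed without terminating.
bool check_fails(MockStatus s) {
  try {
    CHECK_MOCK(s);
    return false;
  } catch (const std::runtime_error&) {
    return true;
  }
}
```
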
#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_