Commit 8b27a2b7 authored by yuguo

[DCU] support ROCm GEMM via rocBLAS

parent 73f3ac47
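In short: the public GEMM entry points gain two trailing flags, `nvte_use_hipblaslt` and `nvte_use_rocblas` (both defaulting to 0), so the ROCm GEMM backend can be chosen per call instead of only through the environment; the dedicated `nvte_cublaslt_gemm` entry point is removed; and the pre-ROCm-6.0 hipBLASLt datatype names are restored through aliases. A minimal usage sketch (the tensor handles `A`, `B`, `D`, etc. stand in for the caller's objects):

```cpp
// Per-call backend selection via the new trailing parameters:
nvte_cublas_gemm(A, B, D, bias, pre_gelu_out, transa, transb, grad,
                 workspace, accumulate, use_split_accumulator,
                 math_sm_count, stream,
                 /*nvte_use_hipblaslt=*/false, /*nvte_use_rocblas=*/true);

// Process-wide selection via environment variables read inside the library:
//   NVTE_USE_ROCBLAS=1             prefer the rocBLAS backend in cublas_gemm
//   NVTE_USE_HIPBLASLT=1           prefer the hipBLASLt backend in cublas_gemm
//   NVTE_FORCE_ROCM_GEMM=1         route nvte_cublas_gemm through cublas_gemm
//   NVTE_FORCE_HIPBLAS_MULSTREAM=1 multi-stream GEMMs take the plain hipBLAS path
```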
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
 void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                       NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                      int math_sm_count, cudaStream_t stream) {
+                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
   NVTE_API_CALL(nvte_cublas_gemm);
   using namespace transformer_engine;
   const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
@@ -521,15 +521,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
 #ifdef __HIP_PLATFORM_AMD__
-#ifdef USE_HIPBLASLT
-  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
+  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')){
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)){
     cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#else
-  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#endif //USE_HIPBLASLT
 #else
   cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
@@ -542,32 +538,33 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
 #endif //__HIP_PLATFORM_AMD__
               grad,
               wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
+#ifdef __HIP_PLATFORM_AMD__
+              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas);
+#else
               math_sm_count, 0, 0, false, nullptr, stream);
+#endif
 #ifdef __HIP_PLATFORM_AMD__
-#ifdef USE_HIPBLASLT
-  }
-  else{
+  } else {
     hipblas_gemm(inputA,
                  inputB,
                  outputD,
                  biasTensor,
                  outputGelu,
                  m, n, k,
                  lda, ldb, ldd,
                  (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                  (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                  grad, wspace->data.dptr,
                  wspace->data.shape[0],
                  accumulate, use_split_accumulator,
                  math_sm_count,
                  0,
                  0,
                  false,
                  nullptr,
                  stream);
   }
-#endif //USE_HIPBLASLT
 #endif //__HIP_PLATFORM_AMD__
 }
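For the AMD build, the hunks above replace the old compile-time `USE_HIPBLASLT` split with a single runtime condition. A restated sketch of that predicate as a self-contained helper (hypothetical; the commit inlines the condition in `nvte_cublas_gemm`):

```cpp
#include <cstdlib>

// Take the cublas_gemm (hipBLASLt/rocBLAS) route when a fused feature or an
// explicit request needs it; otherwise fall back to plain hipblas_gemm.
static bool want_rocm_gemm_path(bool has_bias, bool has_gelu, bool use_fp8,
                                bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
  const char *force = std::getenv("NVTE_FORCE_ROCM_GEMM");
  return has_bias || has_gelu || use_fp8 ||
         (force != nullptr && force[0] == '1') ||
         nvte_use_hipblaslt || nvte_use_rocblas;
}
```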
@@ -577,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream) {
+                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
   NVTE_API_CALL(nvte_cublas_atomic_gemm);
 #ifndef __HIP_PLATFORM_AMD__
@@ -622,15 +619,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 #ifdef __HIP_PLATFORM_AMD__
-#ifdef USE_HIPBLASLT
-  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
+  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')){
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)){
     cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#else
-  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#endif //USE_HIPBLASLT
 #else
   cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
@@ -643,81 +636,38 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 #endif //__HIP_PLATFORM_AMD__
               grad,
               wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
+#ifdef __HIP_PLATFORM_AMD__
+              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas);
+#else
               math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
+#endif
 #ifdef __HIP_PLATFORM_AMD__
-#ifdef USE_HIPBLASLT
-  }
-  else{
-    hipblas_gemm(
-                 inputA,
-                 inputB,
-                 outputD,
-                 biasTensor,
-                 outputGelu,
-                 m, n, k,
-                 lda, ldb, ldd,
-                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-                 grad, wspace->data.dptr,
-                 wspace->data.shape[0],
-                 accumulate, use_split_accumulator,
-                 math_sm_count,
-                 m_split,
-                 n_split,
-                 gemm_producer,
-                 inputCounter,
-                 stream);
+  } else {
+    hipblas_gemm(inputA,
+                 inputB,
+                 outputD,
+                 biasTensor,
+                 outputGelu,
+                 m, n, k,
+                 lda, ldb, ldd,
+                 (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+                 (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+                 grad, wspace->data.dptr,
+                 wspace->data.shape[0],
+                 accumulate, use_split_accumulator,
+                 math_sm_count,
+                 m_split,
+                 n_split,
+                 gemm_producer,
+                 inputCounter,
+                 stream);
   }
-#endif //USE_HIPBLASLT
 #endif //__HIP_PLATFORM_AMD__
 }
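`nvte_cublas_atomic_gemm` receives the identical dispatch; the only difference from the plain entry point is that `m_split`, `n_split`, `gemm_producer`, and `inputCounter` are forwarded to `cublas_gemm` where the plain GEMM passes zeros, `false`, and `nullptr`.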
-void nvte_cublaslt_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
-                        NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
-                        NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                        int math_sm_count, cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cublaslt_gemm);
-  using namespace transformer_engine;
-  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
-  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
-  Tensor *outputD = reinterpret_cast<Tensor *>(D);
-  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
-  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
-  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
-  const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
-  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
-  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
-  int lda, ldb, ldd;
-  if (transa && !transb) {  // TN
-    lda = k;
-    ldb = k;
-    ldd = m;
-  } else if (!transa && !transb) {  // NN
-    lda = m;
-    ldb = k;
-    ldd = m;
-  } else if (!transa && transb) {  // NT
-    lda = m;
-    ldb = n;
-    ldd = m;
-  } else {  // TT
-    NVTE_ERROR("TT layout not allowed.");
-  }
-  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-#ifdef __HIP_PLATFORM_AMD__
-              transa, transb,
-#else
-              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
-#endif //__HIP_PLATFORM_AMD__
-              grad,
-              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
-              math_sm_count, 0, 0, false, nullptr, stream);
-}
 void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                    const NVTETensor *bias, NVTETensor *pre_gelu_out,
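The standalone `nvte_cublaslt_gemm` wrapper above is deleted without any replacement logic of its own: it only computed the m/n/k/ld parameters and forwarded to `cublas_gemm`, which `nvte_cublas_gemm` already does. The multi-stream loop below keeps its behaviour by calling the surviving entry point with the hipBLASLt flag raised, along the lines of:

```cpp
// Equivalent of the removed nvte_cublaslt_gemm call, expressed through
// nvte_cublas_gemm (arguments as in the multi-stream loop below):
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                 workspace[i % num_streams], accumulate, use_split_accumulator,
                 math_sm_count, compute_streams[i % num_streams],
                 /*nvte_use_hipblaslt=*/1, /*nvte_use_rocblas=*/0);
```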
@@ -736,20 +686,19 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
   for (int s = 0; s < num_stream_used; s++) {
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
   }
-  const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
-  const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
-  bool NVTE_FORCE_BLASLT_MULSTREAM;
-  if(NVTE_BLAS_MULSTREAM==nullptr){
-    NVTE_FORCE_BLASLT_MULSTREAM = true;
-  } else if((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')){
-    NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
+  const char *NVTE_HIPBLAS_MULSTREAM = std::getenv("NVTE_FORCE_HIPBLAS_MULSTREAM");
+  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
+  bool NVTE_FORCE_HIPBLAS_MULSTREAM;
+  if(NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1'){
+    NVTE_FORCE_HIPBLAS_MULSTREAM = true;
+    if((NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') && (NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1'))
+      NVTE_ERROR("NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
   } else{
-    NVTE_FORCE_BLASLT_MULSTREAM = false;
+    NVTE_FORCE_HIPBLAS_MULSTREAM = false;
   }
-  if (NVTE_FORCE_BLASLT_MULSTREAM){
+  if (NVTE_FORCE_HIPBLAS_MULSTREAM){
     for (int i = 0; i < num_gemms; i++) {
-      nvte_cublaslt_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
+      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                        workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                        compute_streams[i % num_streams]);
     }
@@ -757,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
     for (int i = 0; i < num_gemms; i++) {
       nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                        workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                       compute_streams[i % num_streams]);
+                       compute_streams[i % num_streams], 1, 0);
     }
   }
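Note the flag semantics in the hunk above: `NVTE_FORCE_HIPBLAS_MULSTREAM=1` now opts in to the plain-hipBLAS multi-stream path (the old code defaulted to the hipBLASLt path whenever `NVTE_FORCE_BLAS_MULSTREAM` was unset), and combining it with `NVTE_FORCE_ROCM_GEMM=1` is rejected. The checks are bare `getenv` plus `[0] == '1'` tests; a minimal sketch of a helper that would express them more compactly (hypothetical, the commit inlines the tests):

```cpp
#include <cstdlib>

// Treat an environment variable as a boolean flag that must be set to "1",
// mirroring the NVTE_FORCE_* checks above.
static bool env_flag_set(const char *name) {
  const char *v = std::getenv(name);
  return v != nullptr && v[0] == '1';
}

// Usage mirroring nvte_multi_stream_cublas_gemm:
//   if (env_flag_set("NVTE_FORCE_HIPBLAS_MULSTREAM") &&
//       env_flag_set("NVTE_FORCE_ROCM_GEMM"))
//     NVTE_ERROR("NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM "
//                "can't be set at the same time.");
```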
...
@@ -36,26 +36,30 @@ namespace {
 #ifdef USE_HIPBLASLT
-static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
+#if HIP_VERSION >= 60000000
+typedef hipDataType hipblasltDatatype_t;
+typedef hipblasComputeType_t hipblasLtComputeType_t;
+#define HIPBLASLT_R_16F HIP_R_16F
+#define HIPBLASLT_R_32F HIP_R_32F
+#define HIPBLASLT_R_16B HIP_R_16BF
+#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
+#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
+#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
+#endif // #if HIP_VERSION >= 60000000
+hipblasltDatatype_t get_hipblaslt_dtype(const transformer_engine::DType t) {
   using namespace transformer_engine;
   switch (t) {
     case DType::kFloat16:
-      return HIP_R_16F;
+      return HIPBLASLT_R_16F;
     case DType::kFloat32:
-      return HIP_R_32F;
+      return HIPBLASLT_R_32F;
     case DType::kBFloat16:
-      return HIP_R_16BF;
+      return HIPBLASLT_R_16B;
-#if HIP_VERSION >= 60300000
     case DType::kFloat8E4M3:
-      return te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
+      return HIPBLASLT_R_8F_E4M3;
     case DType::kFloat8E5M2:
-      return te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ: HIP_R_8F_E5M2;
+      return HIPBLASLT_R_8F_E5M2;
-#else
-    case DType::kFloat8E4M3:
-      return HIP_R_8F_E4M3_FNUZ;
-    case DType::kFloat8E5M2:
-      return HIP_R_8F_E5M2_FNUZ;
-#endif
     default:
       NVTE_ERROR("Invalid type");
   }
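Background for the hunk above: ROCm 6.0 dropped hipBLASLt's own `hipblasltDatatype_t`/`hipblasLtComputeType_t` in favour of the plain HIP types, so the `typedef`s and `#define`s recreate the old spellings on new toolchains, and the `te_fp8_fnuz()` runtime branch disappears because the FP8 aliases are pinned to the FNUZ variants. A usage sketch of how the aliases land on a ROCm >= 6.0 build:

```cpp
// The aliased names resolve to plain HIP types on ROCm >= 6.0:
//   HIPBLASLT_R_16F       -> HIP_R_16F
//   HIPBLASLT_R_32F       -> HIP_R_32F
//   HIPBLASLT_R_16B       -> HIP_R_16BF
//   HIPBLASLT_R_8F_E4M3   -> HIP_R_8F_E4M3_FNUZ
//   HIPBLASLT_R_8F_E5M2   -> HIP_R_8F_E5M2_FNUZ
//   HIPBLASLT_COMPUTE_F32 -> HIPBLAS_COMPUTE_32F
hipblasltDatatype_t t = get_hipblaslt_dtype(transformer_engine::DType::kBFloat16);
// t == HIPBLASLT_R_16B, i.e. HIP_R_16BF on a ROCm 6 toolchain.
```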
@@ -363,7 +367,11 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
   if(! stream_order_alloc){
     NVTE_CHECK_CUDA( hipMemset(out, 0, n*sizeof(float)) );
   }else{
+#if HIP_VERSION >= 50300000
     NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) );
+#else
+    NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
+#endif
   }
   hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
 }
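This `HIP_VERSION >= 50300000` guard is the first of several: the same pattern wraps every `hipMallocAsync`/`hipFreeAsync`/`hipMemsetAsync` call in `rocblas_gemm` further down, failing loudly when stream-ordered allocation is requested on a ROCm runtime older than 5.3. A hedged sketch of a helper that would fold the repetition into one place (hypothetical; the commit inlines the guard at each site):

```cpp
// Allocate stream-ordered memory when the toolchain supports it (ROCm 5.3+);
// otherwise fail rather than silently mixing allocation models.
static void *te_malloc_async(size_t bytes, hipStream_t stream) {
  void *ptr = nullptr;
#if HIP_VERSION >= 50300000
  NVTE_CHECK_CUDA(hipMallocAsync(&ptr, bytes, stream));
#else
  NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
  return ptr;
}
```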
@@ -567,11 +575,11 @@ public:
   const std::string_view &getName(const T &val) {
     return map.at(val);
   }
-  T getValue(const std::string& name, const char *label="", std::function<bool(const T&)> filter = nullptr)
+  T getValue(const std::string& name, const char *label="")
   {
     for (auto iter = map.begin(); iter != map.end(); ++iter)
     {
-      if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
+      if (name == iter->second) return iter->first;
     }
     NVTE_ERROR("Invalid ", label, " name: ", name);
   }
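The dropped `filter` parameter existed only because the old `type_name_map` (next hunk) carried two entries named "float8e4m3" and two named "float8e5m2", the FNUZ and non-FNUZ enum values, and the lambda selected the set matching `te_fp8_fnuz()`. With the map reduced to a single set of entries, every name is unique again and a plain lookup suffices:

```cpp
// After this change each name resolves to exactly one enum value:
hipblasltDatatype_t t = typeNameMapper.getValue("float8e4m3", "type_a");
// t == HIPBLASLT_R_8F_E4M3; no disambiguating filter is needed.
```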
@@ -579,18 +587,14 @@ protected:
   const std::unordered_map<T, std::string_view> &map;
 };
-static std::unordered_map<hipDataType, std::string_view> type_name_map = {
-  {HIP_R_32F, "float32"},
-  {HIP_R_16F, "float16"},
-  {HIP_R_16BF, "bfloat16"},
-  {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
-  {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
-#if HIP_VERSION >= 60300000
-  {HIP_R_8F_E4M3, "float8e4m3"},
-  {HIP_R_8F_E5M2, "float8e5m2"},
-#endif
+static std::unordered_map<hipblasltDatatype_t, std::string_view> type_name_map = {
+  {HIPBLASLT_R_32F, "float32"},
+  {HIPBLASLT_R_16F, "float16"},
+  {HIPBLASLT_R_16B, "bfloat16"},
+  {HIPBLASLT_R_8F_E4M3, "float8e4m3"},
+  {HIPBLASLT_R_8F_E5M2, "float8e5m2"},
 };
-static NameMapper<hipDataType> typeNameMapper(type_name_map);
+static NameMapper<hipblasltDatatype_t> typeNameMapper(type_name_map);
 static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
   {HIPBLAS_OP_N, "N"},
@@ -609,24 +613,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
 };
 static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
-static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
-  {HIPBLAS_COMPUTE_32F, "f32"}
+static std::unordered_map<hipblasLtComputeType_t, std::string_view> comp_name_map = {
+  {HIPBLASLT_COMPUTE_F32, "f32"}
 };
-static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
+static NameMapper<hipblasLtComputeType_t> computeNameMapper(comp_name_map);
 static class GemmAlgoCache {
  public:
   struct Key {
     int deviceCap;
-    hipDataType a_type, b_type, d_type, bias_type;
+    hipblasltDatatype_t a_type, b_type, d_type, bias_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
     hipblasLtEpilogue_t epilogue;
     Key(int deviceCap_,
-        hipDataType a_type_, hipDataType b_type_,
-        hipDataType d_type_, hipDataType bias_type_,
+        hipblasltDatatype_t a_type_, hipblasltDatatype_t b_type_,
+        hipblasltDatatype_t d_type_, hipblasltDatatype_t bias_type_,
         int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         hipblasLtEpilogue_t epilogue_):
@@ -860,32 +864,18 @@ protected:
         std::cout << "[WARNING] Invalid WS size at " << line << "\n";
         continue;
       }
-#if HIP_VERSION >= 60300000
-      auto fp8_filter = te_fp8_fnuz()
-        ? [](const hipDataType& val)
-          { return (val != HIP_R_8F_E4M3 && val != HIP_R_8F_E5M2); }
-        : [](const hipDataType& val) {
-            return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
-          };
-#else
-      auto fp8_filter = nullptr;
-#endif
-      cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
-      cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
-      cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
-      cfg.bias_type = (bias_type == "-")
-        ? (hipDataType)-1
-        : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
+      cfg.a_type = typeNameMapper.getValue(type_a, "type_a");
+      cfg.b_type = typeNameMapper.getValue(type_b, "type_b");
+      cfg.d_type = typeNameMapper.getValue(type_d, "type_d");
+      cfg.bias_type = (bias_type == "-") ? (hipblasltDatatype_t)-1 : typeNameMapper.getValue(bias_type, "bias_type");
       cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
       cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
       cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
       //Check and filter out compute and scale types
-      if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
-          typeNameMapper.getValue(scale, "scale") != HIP_R_32F)
+      if (computeNameMapper.getValue(comp, "comp") != HIPBLASLT_COMPUTE_F32 || typeNameMapper.getValue(scale, "scale") != HIPBLASLT_R_32F)
       {
         continue;
       }
@@ -968,9 +958,9 @@ protected:
       csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k
           << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
           << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
-          << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+          << ((cfg.bias_type == (hipblasltDatatype_t)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
           << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
-          << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
+          << computeNameMapper.getName(HIPBLASLT_COMPUTE_F32) << typeNameMapper.getName(HIPBLASLT_R_32F)
          << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
     }
@@ -1036,10 +1026,10 @@ void hipblaslt_gemm(const Tensor *inputA,
   const bool gelu = pre_gelu_out != nullptr;
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
-  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
-  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
-  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
+  const hipblasltDatatype_t A_type = get_hipblaslt_dtype(inputA->data.dtype);
+  const hipblasltDatatype_t B_type = get_hipblaslt_dtype(inputB->data.dtype);
+  const hipblasltDatatype_t D_type = get_hipblaslt_dtype(outputD->data.dtype);
+  const hipblasltDatatype_t bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
   NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
@@ -1073,7 +1063,7 @@ void hipblaslt_gemm(const Tensor *inputA,
   int64_t ld_gelumat = (int64_t) ldd;
   // default to tf32 except for e5m2 inputs where the config is not supported
-  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
+  hipblasLtComputeType_t gemm_compute_type = HIPBLASLT_COMPUTE_F32;
   // Create matrix descriptors. Not setting any extra attributes.
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
@@ -1086,7 +1076,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                    ldb));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
+  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIPBLASLT_R_32F));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                        &transa, sizeof(transa)));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
@@ -1163,7 +1153,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                        &epilogue, sizeof(epilogue)));
   GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
-                              use_fp8 ? bias_type : (hipDataType)-1,
+                              use_fp8 ? bias_type : (hipblasltDatatype_t)-1,
                               m, n, k, lda, ldb, ldd, transa, transb, epilogue );
   GemmAlgoCache::Algo cached_algo;
   if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
@@ -1478,7 +1468,11 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipMalloc(&D_temp, sizeof(float)*m*n) );
     }else{
+#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipMallocAsync(&D_temp, sizeof(float)*m*n, stream) );
+#else
+      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
+#endif
     }
   }else {
     D_temp = D;
@@ -1571,7 +1565,11 @@ void rocblas_gemm(const Tensor *inputA,
       if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*input_dim) ); // The bias gradient is for the first linear layer
       }else{
+#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*input_dim, stream) );
+#else
+        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
+#endif
       }
     }else {
       bias_tmp = bias_ptr;
@@ -1597,7 +1595,11 @@ void rocblas_gemm(const Tensor *inputA,
       if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipFree(bias_tmp) );
       }else{
+#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
+#else
+        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
+#endif
       }
     }
@@ -1645,7 +1647,11 @@ void rocblas_gemm(const Tensor *inputA,
       if(! stream_order_alloc){
         NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*output_dim) );
       }else{
+#if HIP_VERSION >= 50300000
         NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*output_dim, stream) );
+#else
+        NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
+#endif
       }
     }else {
       bias_tmp = bias_ptr;
if(! stream_order_alloc){ if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(bias_tmp) ); NVTE_CHECK_CUDA( hipFree(bias_tmp) );
}else{ }else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) ); NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
} }
} }
if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) { if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
@@ -1773,7 +1783,11 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipFree(D_temp) );
     }else{
+#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(D_temp, stream) );
+#else
+      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
+#endif
     }
   }
 }
@@ -1785,15 +1799,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  int ldb, int ldd, bool transa, bool transb, bool grad,
                  void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                  int math_sm_count, int m_split, int n_split, bool gemm_producer,
-                 const Tensor *inputCounter, hipStream_t stream)
+                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0)
 {
   /*If no backend is specified with env variable use HIPBLASLT unless it is disabled
     If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
     Otherwise use ROCBLAS
   */
-  bool use_hipblaslt = std::getenv("NVTE_USE_HIPBLASLT") != nullptr;
-  bool use_rocblas = std::getenv("NVTE_USE_ROCBLAS") != nullptr;
+  bool use_hipblaslt = (std::getenv("NVTE_USE_HIPBLASLT") != nullptr) || nvte_use_hipblaslt;
+  bool use_rocblas = (std::getenv("NVTE_USE_ROCBLAS") != nullptr) || nvte_use_rocblas;
 #if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
 #error GEMM backend is not specified
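Note the asymmetry in how the knobs are read: `NVTE_USE_HIPBLASLT` and `NVTE_USE_ROCBLAS` count as set when the variable merely exists, whatever its value, while the `NVTE_FORCE_*` variables elsewhere in this commit must equal `1`. The new per-call flags are OR-ed with the environment, so either source can switch a backend on, and neither can switch the other off.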
@@ -1813,12 +1827,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   if (use_hipblaslt && use_rocblas)
   {
     use_rocblas = false;
+    use_hipblaslt = true;
     std::cout << "[NOTICE] Two GEMM backends are enabled, hipBLASLt will be used\n";
+  } else if (!use_hipblaslt && !use_rocblas)
+  {
+    use_rocblas = false;
+    use_hipblaslt = true;
+    std::cout << "[NOTICE] Two GEMM backends are disabled, hipBLASLt will be used\n";
   }
 #endif
 #ifdef USE_HIPBLASLT
-  if (use_hipblaslt || !use_rocblas)
+  if (use_hipblaslt)
   {
     hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
                    m, n, k, lda, ldb, ldd,
@@ -1833,6 +1853,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif
 #ifdef USE_ROCBLAS
+  if (use_rocblas)
   {
     rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
                  m, n, k, lda, ldb, ldd,
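With both backends compiled in, the selection now happens entirely up front, which is why the old `use_hipblaslt || !use_rocblas` condition collapses to a plain `use_hipblaslt` and the rocBLAS branch gains an explicit `if (use_rocblas)`. A compilable restatement of the resolution (a sketch, not code from the commit):

```cpp
// Backend resolution in cublas_gemm when USE_HIPBLASLT and USE_ROCBLAS are
// both defined; "requested" means env var set or per-call flag true.
enum class GemmBackend { kHipblaslt, kRocblas };

static GemmBackend resolve_backend(bool hipblaslt_requested, bool rocblas_requested) {
  if (hipblaslt_requested && rocblas_requested)
    return GemmBackend::kHipblaslt;   // both requested: hipBLASLt wins, notice printed
  if (!hipblaslt_requested && !rocblas_requested)
    return GemmBackend::kHipblaslt;   // neither requested: hipBLASLt default, notice printed
  return rocblas_requested ? GemmBackend::kRocblas : GemmBackend::kHipblaslt;
}
```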
...
@@ -42,7 +42,7 @@ extern "C" {
 void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                       NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                      int math_sm_count, cudaStream_t stream);
+                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
 /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
  *
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream);
+                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
 /*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
  *         on multiple streams.
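One caveat on the declarations above: the `= 0` defaults on the new `bool` parameters are a C++ feature, so existing C++ call sites keep compiling unchanged, but a caller compiled as plain C must now spell out all arguments explicitly, e.g.:

```cpp
// From a C translation unit the trailing flags must be passed explicitly:
nvte_cublas_gemm(A, B, D, bias, pre_gelu_out, transa, transb, grad,
                 workspace, accumulate, use_split_accumulator,
                 math_sm_count, stream, 0, 0);
```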
...