"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "a8f0fe03b4ac801ee50b35ff24ec2998eaa301ac"
Commit 229be5e8 authored by yuguo

[DCU] new rocm gemm

parent 388ac735
@@ -451,7 +451,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                     use_split_accumulator, _math_sms, _stream_compute[0]);
+                     use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
     for (int i = 1; i < _num_splits; i++) {
       input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
@@ -462,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
       nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0);
+                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
       NVTE_CHECK_CUDA(
           cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
@@ -510,7 +510,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0);
+                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
     NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
@@ -821,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                        aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                        use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0);
+                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
       if (i < num_steps - 1) {
         // P2P communication
@@ -865,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                        aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                        use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0);
+                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
       if (i < _tp_size - 1) {
         // P2P communication
@@ -1010,7 +1010,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
     nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                     use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0);
+                     use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
     if (i > 0) {
       // P2P communication chunk
...
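Note on the new trailing arguments: under ROCm, nvte_cublas_gemm takes three extra parameters (nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset), and each overlap path above now forwards its chunk's stream index as the offset. A minimal sketch of the resulting call shape (tensor setup elided; the flag and bool values are illustrative placeholders, not taken from this diff):

    #include <transformer_engine/gemm.h>  // nvte_cublas_gemm (ROCm signature)

    // Sketch only: A, B, D, bias, pre_gelu_out, workspace are NVTETensor
    // handles prepared elsewhere; stream_idx is the chunk's compute-stream slot.
    void launch_chunk_gemm(NVTETensor A, NVTETensor B, NVTETensor D,
                           NVTETensor bias, NVTETensor pre_gelu_out,
                           NVTETensor workspace, cudaStream_t stream, int stream_idx) {
      nvte_cublas_gemm(A, B, D, bias, pre_gelu_out,
                       /*transa=*/true, /*transb=*/false, /*grad=*/false,
                       workspace, /*accumulate=*/false, /*use_split_accumulator=*/false,
                       /*math_sm_count=*/0, stream,
                       /*nvte_use_hipblaslt=*/1,   // same flags the overlap paths pass
                       /*nvte_use_rocblas=*/0,
                       /*compute_stream_offset=*/stream_idx);  // selects the per-stream handle
    }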
@@ -163,7 +163,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  int ldb, int ldd, bool transa, bool transb, bool grad,
                  void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                  int math_sm_count, int m_split, int n_split, bool gemm_producer,
-                 const Tensor *inputCounter, hipStream_t stream);
+                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
 #else  // Use cublasLt
 using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
 void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                       NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
+                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
   NVTE_API_CALL(nvte_cublas_gemm);
   using namespace transformer_engine;
   const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
@@ -539,7 +539,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
               grad,
               wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
 #ifdef __HIP_PLATFORM_AMD__
-              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas);
+              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
 #else
               math_sm_count, 0, 0, false, nullptr, stream);
 #endif
@@ -574,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
+                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
   NVTE_API_CALL(nvte_cublas_atomic_gemm);
 #ifndef __HIP_PLATFORM_AMD__
@@ -637,7 +637,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
               grad,
               wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
 #ifdef __HIP_PLATFORM_AMD__
-              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas);
+              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
 #else
               math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
 #endif
@@ -706,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
   for (int i = 0; i < num_gemms; i++) {
     nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                      workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                     compute_streams[i % num_streams], 1, 0);
+                     compute_streams[i % num_streams], 1, 0, i % num_streams);
   }
 }
...
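The multi-stream loop keeps workspace, compute stream, and (now) hipBLASLt handle keyed by the same index. A condensed restatement of that invariant with the slot named explicitly (a sketch of the loop above, not additional code in the commit):

    // slot ties together stream, workspace, and hipBLASLt handle for GEMM i.
    for (int i = 0; i < num_gemms; i++) {
      const int slot = i % num_streams;
      nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                       workspace[slot], accumulate, use_split_accumulator, math_sm_count,
                       compute_streams[slot], 1, 0, /*compute_stream_offset=*/slot);
    }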
@@ -37,30 +37,26 @@ namespace {
 #ifdef USE_HIPBLASLT
-#if HIP_VERSION >= 60000000
-typedef hipDataType hipblasltDatatype_t;
-typedef hipblasComputeType_t hipblasLtComputeType_t;
-#define HIPBLASLT_R_16F HIP_R_16F
-#define HIPBLASLT_R_32F HIP_R_32F
-#define HIPBLASLT_R_16B HIP_R_16BF
-#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
-#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
-#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
-#endif  // #if HIP_VERSION >= 60000000
-hipblasltDatatype_t get_hipblaslt_dtype(const transformer_engine::DType t) {
+static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
   using namespace transformer_engine;
   switch (t) {
     case DType::kFloat16:
-      return HIPBLASLT_R_16F;
+      return HIP_R_16F;
     case DType::kFloat32:
-      return HIPBLASLT_R_32F;
+      return HIP_R_32F;
     case DType::kBFloat16:
-      return HIPBLASLT_R_16B;
+      return HIP_R_16BF;
+#if HIP_VERSION >= 60300000
     case DType::kFloat8E4M3:
-      return HIPBLASLT_R_8F_E4M3;
+      return te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
     case DType::kFloat8E5M2:
-      return HIPBLASLT_R_8F_E5M2;
+      return te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ : HIP_R_8F_E5M2;
+#else
+    case DType::kFloat8E4M3:
+      return HIP_R_8F_E4M3_FNUZ;
+    case DType::kFloat8E5M2:
+      return HIP_R_8F_E5M2_FNUZ;
+#endif
     default:
       NVTE_ERROR("Invalid type");
   }
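In effect, the FP8 enums now resolve at runtime on ROCm >= 6.3 and fall back to the FNUZ encodings on older ROCm. A hedged illustration of the mapping (assuming te_fp8_fnuz() reports whether the device natively uses the FNUZ rather than the OCP FP8 format; that reading is an inference from this diff, not stated in it):

    // Illustration of get_hipblaslt_dtype(DType::kFloat8E4M3):
    //   ROCm >= 6.3, te_fp8_fnuz() == true  -> HIP_R_8F_E4M3_FNUZ
    //   ROCm >= 6.3, te_fp8_fnuz() == false -> HIP_R_8F_E4M3
    //   ROCm <  6.3 (any device)            -> HIP_R_8F_E4M3_FNUZ
    hipDataType t = get_hipblaslt_dtype(transformer_engine::DType::kFloat8E4M3);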
@@ -368,11 +364,7 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
   if(! stream_order_alloc){
     NVTE_CHECK_CUDA( hipMemset(out, 0, n*sizeof(float)) );
   }else{
-#if HIP_VERSION >= 50300000
     NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) );
-#else
-    NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
   }
   hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
 }
@@ -576,11 +568,11 @@ public:
   const std::string_view &getName(const T &val) {
     return map.at(val);
   }
-  T getValue(const std::string& name, const char *label="")
+  T getValue(const std::string& name, const char *label="", std::function<bool(const T&)> filter = nullptr)
   {
     for (auto iter = map.begin(); iter != map.end(); ++iter)
     {
-      if (name == iter->second) return iter->first;
+      if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
     }
     NVTE_ERROR("Invalid ", label, " name: ", name);
   }
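The new filter argument exists because a name can now map to more than one enum value: with ROCm >= 6.3, both the FNUZ and OCP FP8 entries stringify to "float8e4m3"/"float8e5m2" (see type_name_map below). A minimal usage sketch, mirroring the fp8_filter lambda added further down:

    #include <functional>

    // Resolve "float8e4m3" to the FNUZ encoding, skipping the OCP entry.
    hipDataType t = typeNameMapper.getValue(
        "float8e4m3", "type_a",
        [](const hipDataType& v) { return v != HIP_R_8F_E4M3 && v != HIP_R_8F_E5M2; });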
@@ -588,14 +580,18 @@ protected:
   const std::unordered_map<T, std::string_view> &map;
 };
-static std::unordered_map<hipblasltDatatype_t, std::string_view> type_name_map = {
-  {HIPBLASLT_R_32F, "float32"},
-  {HIPBLASLT_R_16F, "float16"},
-  {HIPBLASLT_R_16B, "bfloat16"},
-  {HIPBLASLT_R_8F_E4M3, "float8e4m3"},
-  {HIPBLASLT_R_8F_E5M2, "float8e5m2"},
+static std::unordered_map<hipDataType, std::string_view> type_name_map = {
+  {HIP_R_32F, "float32"},
+  {HIP_R_16F, "float16"},
+  {HIP_R_16BF, "bfloat16"},
+  {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
+  {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
+#if HIP_VERSION >= 60300000
+  {HIP_R_8F_E4M3, "float8e4m3"},
+  {HIP_R_8F_E5M2, "float8e5m2"},
+#endif
 };
-static NameMapper<hipblasltDatatype_t> typeNameMapper(type_name_map);
+static NameMapper<hipDataType> typeNameMapper(type_name_map);
 static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
   {HIPBLAS_OP_N, "N"},
@@ -614,24 +610,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
 };
 static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
-static std::unordered_map<hipblasLtComputeType_t, std::string_view> comp_name_map = {
-  {HIPBLASLT_COMPUTE_F32, "f32"}
+static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
+  {HIPBLAS_COMPUTE_32F, "f32"}
 };
-static NameMapper<hipblasLtComputeType_t> computeNameMapper(comp_name_map);
+static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
 static class GemmAlgoCache {
  public:
   struct Key {
     int deviceCap;
-    hipblasltDatatype_t a_type, b_type, d_type, bias_type;
+    hipDataType a_type, b_type, d_type, bias_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
     hipblasLtEpilogue_t epilogue;
     Key(int deviceCap_,
-        hipblasltDatatype_t a_type_, hipblasltDatatype_t b_type_,
-        hipblasltDatatype_t d_type_, hipblasltDatatype_t bias_type_,
+        hipDataType a_type_, hipDataType b_type_,
+        hipDataType d_type_, hipDataType bias_type_,
         int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         hipblasLtEpilogue_t epilogue_):
@@ -865,18 +861,32 @@ protected:
         std::cout << "[WARNING] Invalid WS size at " << line << "\n";
         continue;
       }
-      cfg.a_type = typeNameMapper.getValue(type_a, "type_a");
-      cfg.b_type = typeNameMapper.getValue(type_b, "type_b");
-      cfg.d_type = typeNameMapper.getValue(type_d, "type_d");
-      cfg.bias_type = (bias_type == "-") ? (hipblasltDatatype_t)-1 : typeNameMapper.getValue(bias_type, "bias_type");
+#if HIP_VERSION >= 60300000
+      auto fp8_filter = te_fp8_fnuz()
+          ? [](const hipDataType& val)
+            { return (val != HIP_R_8F_E4M3 && val != HIP_R_8F_E5M2); }
+          : [](const hipDataType& val) {
+              return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
+            };
+#else
+      auto fp8_filter = nullptr;
+#endif
+      cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
+      cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
+      cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
+      cfg.bias_type = (bias_type == "-")
+          ? (hipDataType)-1
+          : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
       cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
       cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
       cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
       //Check and filter out compute and scale types
-      if (computeNameMapper.getValue(comp, "comp") != HIPBLASLT_COMPUTE_F32 || typeNameMapper.getValue(scale, "scale") != HIPBLASLT_R_32F)
+      if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
+          typeNameMapper.getValue(scale, "scale") != HIP_R_32F)
       {
         continue;
       }
@@ -959,9 +969,9 @@ protected:
       csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k
           << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
           << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
-          << ((cfg.bias_type == (hipblasltDatatype_t)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+          << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
          << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
-          << computeNameMapper.getName(HIPBLASLT_COMPUTE_F32) << typeNameMapper.getName(HIPBLASLT_R_32F)
+          << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
           << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
     }
@@ -995,6 +1005,19 @@ static inline int getIntEnv(const char *name, int defval, int minval)
 } //namespace
+/* Warning: only call once per device!
+ * When calling nvte_multi_stream_cublas_gemm with the hipblaslt backend,
+ * we need to create one handle per compute stream so that a single handle
+ * is never used by multiple streams concurrently.
+ */
+static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
+  NVTE_CHECK(hipblaslt_handles != nullptr);
+  for (int i = 0; i < num_streams; i++) {
+    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
+  }
+}
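How these handles are consumed is shown further down in cublas_gemm; condensed here for reference (the same code, gathered in one place rather than anything new):

    // One-time, process-wide creation of num_streams handles, followed by a
    // lookup keyed on the caller's compute_stream_offset:
    static std::once_flag init_flag;
    static hipblasLtHandle_t hipblaslt_handles[num_streams];
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
    hipblasLtHandle_t handle = hipblaslt_handles[compute_stream_offset];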
 void hipblaslt_gemm(const Tensor *inputA,
                     const Tensor *inputB,
                     Tensor *outputD,
@@ -1014,7 +1037,8 @@ void hipblaslt_gemm(const Tensor *inputA,
                     int n_split,
                     bool gemm_producer,
                     const Tensor *inputCounter,
-                    hipStream_t stream
+                    hipStream_t stream,
+                    hipblasLtHandle_t handle
 ) {
   void *A = inputA->data.dptr;
   void *A_scale_inverse = inputA->scale_inv.dptr;
@@ -1027,10 +1051,10 @@ void hipblaslt_gemm(const Tensor *inputA,
   const bool gelu = pre_gelu_out != nullptr;
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  const hipblasltDatatype_t A_type = get_hipblaslt_dtype(inputA->data.dtype);
-  const hipblasltDatatype_t B_type = get_hipblaslt_dtype(inputB->data.dtype);
-  const hipblasltDatatype_t D_type = get_hipblaslt_dtype(outputD->data.dtype);
-  const hipblasltDatatype_t bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
+  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
+  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
+  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
+  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
   NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
@@ -1050,10 +1074,12 @@ void hipblaslt_gemm(const Tensor *inputA,
   int device_id;
   NVTE_CHECK_CUDA(hipGetDevice(&device_id));
-  hipblasLtHandle_t handle = cached_handles.get(device_id);
-  if (handle == nullptr)
-  {
-    handle = cached_handles.obtain(device_id);
+  if (handle == nullptr) {
+    handle = cached_handles.get(device_id);
+    if (handle == nullptr)
+    {
+      handle = cached_handles.obtain(device_id);
+    }
   }
   hipblasLtMatmulDesc_t operationDesc = nullptr;
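So the new handle parameter is a pure override: callers that pass nullptr keep the old per-device cached handle, while the multi-stream path pins each GEMM to the handle created for its stream slot. A sketch of that selection as a hypothetical helper (pick_handle is illustrative, not part of the commit):

    // Sketch: choosing the handle the way the new cublas_gemm logic does.
    hipblasLtHandle_t pick_handle(int compute_stream_offset,
                                  hipblasLtHandle_t* per_stream_handles) {
      // -1 means "no stream slot": hipblaslt_gemm then falls back to the
      // per-device cached handle when it receives nullptr.
      return (compute_stream_offset == -1) ? nullptr
                                           : per_stream_handles[compute_stream_offset];
    }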
@@ -1064,7 +1090,7 @@ void hipblaslt_gemm(const Tensor *inputA,
   int64_t ld_gelumat = (int64_t) ldd;
   // default to tf32 except for e5m2 inputs where the config is not supported
-  hipblasLtComputeType_t gemm_compute_type = HIPBLASLT_COMPUTE_F32;
+  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
   // Create matrix descriptors. Not setting any extra attributes.
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
@@ -1077,7 +1103,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                    ldb));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIPBLASLT_R_32F));
+  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                        &transa, sizeof(transa)));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
@@ -1154,7 +1180,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                        &epilogue, sizeof(epilogue)));
   GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
-                              use_fp8 ? bias_type : (hipblasltDatatype_t)-1,
+                              use_fp8 ? bias_type : (hipDataType)-1,
                               m, n, k, lda, ldb, ldd, transa, transb, epilogue );
   GemmAlgoCache::Algo cached_algo;
   if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
@@ -1231,6 +1257,7 @@ void hipblaslt_gemm(const Tensor *inputA,
               << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
               << tuneLoopCount << " loops " << std::endl;
+    NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
     hipStream_t profilingStream;
     NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
     using tuning_clock = std::chrono::steady_clock;
@@ -1475,11 +1502,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipMalloc(&D_temp, sizeof(float)*m*n) );
     }else{
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipMallocAsync(&D_temp, sizeof(float)*m*n, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }else {
     D_temp = D;
@@ -1570,11 +1593,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*input_dim) ); // The bias gradient is for the first linear layer
     }else{
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*input_dim, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }else {
     bias_tmp = bias_ptr;
@@ -1600,11 +1619,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipFree(bias_tmp) );
     }else{
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
@@ -1652,11 +1667,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*output_dim) );
     }else{
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*output_dim, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }else {
     bias_tmp = bias_ptr;
@@ -1683,11 +1694,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipFree(bias_tmp) );
     }else{
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
   if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
@@ -1788,11 +1795,7 @@ void rocblas_gemm(const Tensor *inputA,
     if(! stream_order_alloc){
       NVTE_CHECK_CUDA( hipFree(D_temp) );
     }else{
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(D_temp, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
 }
@@ -1804,7 +1807,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  int ldb, int ldd, bool transa, bool transb, bool grad,
                  void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                  int math_sm_count, int m_split, int n_split, bool gemm_producer,
-                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0)
+                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1)
 {
   /*If no backend is specified with env variable use HIPBLASLT unless it is disabled
     If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
@@ -1845,16 +1848,31 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif
 #ifdef USE_HIPBLASLT
-  if (use_hipblaslt)
+  if (use_hipblaslt || !use_rocblas)
   {
+    // Check that compute_stream_offset is valid.
+    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams);
+    hipblasLtHandle_t handle = nullptr;
+    if (compute_stream_offset != -1) {
+      // Init hipblaslt handles (once, globally)
+      static std::once_flag init_flag;
+      static hipblasLtHandle_t hipblaslt_handles[num_streams];
+      std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
+      handle = hipblaslt_handles[compute_stream_offset];
+    }
     hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
                    m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad,
                    workspace, workspaceSize, accumulate, use_split_accumulator,
                    math_sm_count, m_split, n_split, gemm_producer,
-                   inputCounter, stream);
+                   inputCounter, stream,
+                   handle);
     return;
   }
 #endif
...
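Net effect at this layer: omitting the new argument selects the legacy single-handle path. A summary sketch (not code from the commit) of a plain call that relies on the defaults:

    // compute_stream_offset defaults to -1, so a call like this reaches
    // hipblaslt_gemm with handle == nullptr and uses the per-device cached
    // handle, exactly as before this commit.
    cublas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
                m, n, k, lda, ldb, ldd, transa, transb, grad,
                workspace, workspaceSize, accumulate, use_split_accumulator,
                math_sm_count, m_split, n_split, gemm_producer,
                inputCounter, stream);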
...@@ -42,7 +42,7 @@ extern "C" { ...@@ -42,7 +42,7 @@ extern "C" {
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator, NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0); int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters. /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
* *
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
+                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
 /*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
  * on multiple streams.
...
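Since both declarations default every new parameter, existing call sites of the public API compile unchanged; only the overlap and multi-stream paths pass an explicit offset. For example (a sketch, with tensors prepared elsewhere):

    // Existing callers keep working: the ROCm-specific tail parameters are
    // defaulted (hipblaslt/rocblas flags = 0, compute_stream_offset = -1).
    nvte_cublas_atomic_gemm(A, B, D, bias, pre_gelu_out, transa, transb, grad,
                            workspace, accumulate, use_split_accumulator,
                            math_sm_count, m_split, n_split, gemm_producer,
                            counter, stream);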