Commit 229be5e8 authored by yuguo

[DCU] new ROCm GEMM

parent 388ac735
@@ -451,7 +451,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
@@ -462,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0);
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
@@ -510,7 +510,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0);
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
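These hunks wire the new trailing arguments of nvte_cublas_gemm into the overlap loops: chunk i runs on _stream_compute[i % _stream_compute.size()], and the same round-robin index is now forwarded as the compute-stream offset so the GEMM backend can pick a matching per-stream handle. The event pair above then orders communication behind compute. A minimal sketch of that ordering, with hypothetical stream and event names:

// Sketch: make a communication stream wait for a GEMM issued on a compute
// stream. `compute_stream`, `comm_stream`, and `chunk_done` are hypothetical.
cudaEvent_t chunk_done;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&chunk_done, cudaEventDisableTiming));
// ... enqueue the chunk's GEMM on compute_stream ...
NVTE_CHECK_CUDA(cudaEventRecord(chunk_done, compute_stream));      // GEMM finished
NVTE_CHECK_CUDA(cudaStreamWaitEvent(comm_stream, chunk_done, 0));  // comm waits
// ... enqueue the reduce-scatter / P2P step on comm_stream ...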
@@ -821,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0);
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
if (i < num_steps - 1) {
// P2P communication
@@ -865,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0);
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
if (i < _tp_size - 1) {
// P2P communication
@@ -1010,7 +1010,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0);
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
if (i > 0) {
// P2P communication chunk
@@ -163,7 +163,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
int ldb, int ldd, bool transa, bool transb, bool grad,
void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, hipStream_t stream);
const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
#else // Use cublasLt
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
@@ -539,7 +539,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas);
math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
#else
math_sm_count, 0, 0, false, nullptr, stream);
#endif
@@ -574,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool transb, bool grad, NVTETensor workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
#ifndef __HIP_PLATFORM_AMD__
@@ -637,7 +637,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas);
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
#else
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
#endif
@@ -706,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams], 1, 0);
compute_streams[i % num_streams], 1, 0, i % num_streams);
}
}
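This loop is where compute_stream_offset pays off: each GEMM is issued on compute_streams[i % num_streams], and the same index is forwarded so the hipBLASLt path can select the handle dedicated to that stream. A condensed sketch of the pairing, with the slot index named for clarity (a hypothetical local):

// Sketch: stream and handle selection stay in lockstep via one slot index.
for (int i = 0; i < num_gemms; i++) {
  const int slot = i % num_streams;  // same index for stream and handle
  nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                   workspace[slot], accumulate, use_split_accumulator, math_sm_count,
                   compute_streams[slot], /*nvte_use_hipblaslt=*/1,
                   /*nvte_use_rocblas=*/0, /*compute_stream_offset=*/slot);
}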
@@ -37,30 +37,26 @@ namespace {
#ifdef USE_HIPBLASLT
#if HIP_VERSION >= 60000000
typedef hipDataType hipblasltDatatype_t;
typedef hipblasComputeType_t hipblasLtComputeType_t;
#define HIPBLASLT_R_16F HIP_R_16F
#define HIPBLASLT_R_32F HIP_R_32F
#define HIPBLASLT_R_16B HIP_R_16BF
#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
#endif // #if HIP_VERSION >= 60000000
hipblasltDatatype_t get_hipblaslt_dtype(const transformer_engine::DType t) {
static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return HIPBLASLT_R_16F;
return HIP_R_16F;
case DType::kFloat32:
return HIPBLASLT_R_32F;
return HIP_R_32F;
case DType::kBFloat16:
return HIPBLASLT_R_16B;
return HIP_R_16BF;
#if HIP_VERSION >= 60300000
case DType::kFloat8E4M3:
return HIPBLASLT_R_8F_E4M3;
return te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
case DType::kFloat8E5M2:
return HIPBLASLT_R_8F_E5M2;
return te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ : HIP_R_8F_E5M2;
#else
case DType::kFloat8E4M3:
return HIP_R_8F_E4M3_FNUZ;
case DType::kFloat8E5M2:
return HIP_R_8F_E5M2_FNUZ;
#endif
default:
NVTE_ERROR("Invalid type");
}
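On ROCm >= 6.3 both FP8 encodings exist as hipDataType enumerators, so the mapping above chooses between the OCP and FNUZ flavors at runtime via te_fp8_fnuz() (a helper from elsewhere in this tree) instead of hard-coding FNUZ. A condensed sketch of the dispatch, assuming te_fp8_fnuz() is in scope:

// Sketch: pick the FP8 encoding the device actually implements.
#if HIP_VERSION >= 60300000
const hipDataType e4m3 = te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
const hipDataType e5m2 = te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ : HIP_R_8F_E5M2;
#else
const hipDataType e4m3 = HIP_R_8F_E4M3_FNUZ;  // pre-6.3 headers define FNUZ only
const hipDataType e5m2 = HIP_R_8F_E5M2_FNUZ;
#endif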
@@ -368,11 +364,7 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMemset(out, 0, n*sizeof(float)) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
}
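The version guard around hipMemsetAsync is gone because the stream-ordered memory API is available on every ROCm release this code now supports; the same guard removal repeats below for the hipMallocAsync/hipFreeAsync pairs in rocblas_gemm. A minimal sketch of the pattern these call sites share, with a hypothetical scratch buffer:

// Sketch: stream-ordered scratch lifetime vs. the synchronous fallback.
void* scratch = nullptr;
if (stream_order_alloc) {
  NVTE_CHECK_CUDA(hipMallocAsync(&scratch, n * sizeof(float), stream));
  NVTE_CHECK_CUDA(hipMemsetAsync(scratch, 0, n * sizeof(float), stream));
  // ... kernels consuming `scratch` on `stream` ...
  NVTE_CHECK_CUDA(hipFreeAsync(scratch, stream));
} else {
  NVTE_CHECK_CUDA(hipMalloc(&scratch, n * sizeof(float)));
  NVTE_CHECK_CUDA(hipMemset(scratch, 0, n * sizeof(float)));
  // ... same kernels; cleanup is synchronous ...
  NVTE_CHECK_CUDA(hipFree(scratch));
}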
@@ -576,11 +568,11 @@ public:
const std::string_view &getName(const T &val) {
return map.at(val);
}
T getValue(const std::string& name, const char *label="")
T getValue(const std::string& name, const char *label="", std::function<bool(const T&)> filter = nullptr)
{
for (auto iter = map.begin(); iter != map.end(); ++iter)
{
if (name == iter->second) return iter->first;
if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
}
NVTE_ERROR("Invalid ", label, " name: ", name);
}
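getValue needs the new filter because the name-to-value map below stops being one-to-one: on ROCm >= 6.3, "float8e4m3" names both HIP_R_8F_E4M3_FNUZ and HIP_R_8F_E4M3. The filter lets a lookup reject the encoding the device does not use. A small usage sketch against typeNameMapper (defined further down):

// Sketch: resolve "float8e4m3" to the FNUZ enumerator, rejecting OCP entries.
hipDataType t = typeNameMapper.getValue(
    "float8e4m3", "type_a",
    [](const hipDataType& v) { return v != HIP_R_8F_E4M3 && v != HIP_R_8F_E5M2; });
// On an FNUZ configuration this yields HIP_R_8F_E4M3_FNUZ.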
@@ -588,14 +580,18 @@ protected:
const std::unordered_map<T, std::string_view> &map;
};
static std::unordered_map<hipblasltDatatype_t, std::string_view> type_name_map = {
{HIPBLASLT_R_32F, "float32"},
{HIPBLASLT_R_16F, "float16"},
{HIPBLASLT_R_16B, "bfloat16"},
{HIPBLASLT_R_8F_E4M3, "float8e4m3"},
{HIPBLASLT_R_8F_E5M2, "float8e5m2"},
static std::unordered_map<hipDataType, std::string_view> type_name_map = {
{HIP_R_32F, "float32"},
{HIP_R_16F, "float16"},
{HIP_R_16BF, "bfloat16"},
{HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
{HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
#if HIP_VERSION >= 60300000
{HIP_R_8F_E4M3, "float8e4m3"},
{HIP_R_8F_E5M2, "float8e5m2"},
#endif
};
static NameMapper<hipblasltDatatype_t> typeNameMapper(type_name_map);
static NameMapper<hipDataType> typeNameMapper(type_name_map);
static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
{HIPBLAS_OP_N, "N"},
@@ -614,24 +610,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
};
static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
static std::unordered_map<hipblasLtComputeType_t, std::string_view> comp_name_map = {
{HIPBLASLT_COMPUTE_F32, "f32"}
static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
{HIPBLAS_COMPUTE_32F, "f32"}
};
static NameMapper<hipblasLtComputeType_t> computeNameMapper(comp_name_map);
static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
static class GemmAlgoCache {
public:
struct Key {
int deviceCap;
hipblasltDatatype_t a_type, b_type, d_type, bias_type;
hipDataType a_type, b_type, d_type, bias_type;
int m, n, k;
int lda, ldb, ldd;
hipblasOperation_t transa, transb;
hipblasLtEpilogue_t epilogue;
Key(int deviceCap_,
hipblasltDatatype_t a_type_, hipblasltDatatype_t b_type_,
hipblasltDatatype_t d_type_, hipblasltDatatype_t bias_type_,
hipDataType a_type_, hipDataType b_type_,
hipDataType d_type_, hipDataType bias_type_,
int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
hipblasOperation_t transa_, hipblasOperation_t transb_,
hipblasLtEpilogue_t epilogue_):
@@ -865,18 +861,32 @@ protected:
std::cout << "[WARNING] Invalid WS size at " << line << "\n";
continue;
}
cfg.a_type = typeNameMapper.getValue(type_a, "type_a");
cfg.b_type = typeNameMapper.getValue(type_b, "type_b");
cfg.d_type = typeNameMapper.getValue(type_d, "type_d");
cfg.bias_type = (bias_type == "-") ? (hipblasltDatatype_t)-1 : typeNameMapper.getValue(bias_type, "bias_type");
#if HIP_VERSION >= 60300000
// Returning true keeps an entry; reject whichever FP8 encoding is inactive.
auto fp8_filter = te_fp8_fnuz()
    ? [](const hipDataType& val) { return (val != HIP_R_8F_E4M3 && val != HIP_R_8F_E5M2); }
    : [](const hipDataType& val) { return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ); };
#else
auto fp8_filter = nullptr;
#endif
cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
cfg.bias_type = (bias_type == "-")
? (hipDataType)-1
: typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
// Check compute and scale types; skip entries that are not f32
if (computeNameMapper.getValue(comp, "comp") != HIPBLASLT_COMPUTE_F32 || typeNameMapper.getValue(scale, "scale") != HIPBLASLT_R_32F)
if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
typeNameMapper.getValue(scale, "scale") != HIP_R_32F)
{
continue;
}
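One subtlety in the fp8_filter initializer above: a conditional expression over two different lambdas is only well-formed because both are captureless, so each decays to the common function-pointer type bool (*)(const hipDataType&), which then converts to getValue's std::function parameter. A standalone illustration of that decay:

// Sketch: captureless lambdas share no closure type, but both convert to a
// function pointer, so the ternary picks between them legally.
auto keep_fnuz = [](const hipDataType& v) { return v != HIP_R_8F_E4M3; };
auto keep_ocp  = [](const hipDataType& v) { return v != HIP_R_8F_E4M3_FNUZ; };
bool (*filter)(const hipDataType&) = te_fp8_fnuz() ? keep_fnuz : keep_ocp;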
@@ -959,9 +969,9 @@ protected:
csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k
<< transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
<< typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
<< ((cfg.bias_type == (hipblasltDatatype_t)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
<< ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
<< cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
<< computeNameMapper.getName(HIPBLASLT_COMPUTE_F32) << typeNameMapper.getName(HIPBLASLT_R_32F)
<< computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
<< algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
}
@@ -995,6 +1005,19 @@ static inline int getIntEnv(const char *name, int defval, int minval)
} //namespace
/* Warning: only call once per device!
 * When nvte_multi_stream_cublas_gemm uses the hipblaslt backend, we need to
 * create one handle per compute stream so that no handle is used by
 * multiple streams concurrently.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
NVTE_CHECK(hipblaslt_handles != nullptr);
for (int i = 0; i < num_streams; i++) {
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
}
}
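For context on why this exists: a hipblasLtHandle_t is not safe to share between GEMMs running concurrently on different streams, so the multi-stream path gives each of the num_streams compute streams its own handle. The one-time setup and lookup, mirroring the call site further down (this assumes <mutex> is included and num_streams is a compile-time constant, as that code implies):

// Sketch: lazily create one handle per compute stream, exactly once.
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
hipblasLtHandle_t handle = hipblaslt_handles[compute_stream_offset];  // offset in [0, num_streams)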
void hipblaslt_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
@@ -1014,7 +1037,8 @@ void hipblaslt_gemm(const Tensor *inputA,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
hipStream_t stream
hipStream_t stream,
hipblasLtHandle_t handle
) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA->scale_inv.dptr;
@@ -1027,10 +1051,10 @@ void hipblaslt_gemm(const Tensor *inputA,
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const hipblasltDatatype_t A_type = get_hipblaslt_dtype(inputA->data.dtype);
const hipblasltDatatype_t B_type = get_hipblaslt_dtype(inputB->data.dtype);
const hipblasltDatatype_t D_type = get_hipblaslt_dtype(outputD->data.dtype);
const hipblasltDatatype_t bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
@@ -1050,10 +1074,12 @@ void hipblaslt_gemm(const Tensor *inputA,
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
hipblasLtHandle_t handle = cached_handles.get(device_id);
if (handle == nullptr)
{
handle = cached_handles.obtain(device_id);
// No per-stream handle was supplied: fall back to the per-device cached handle.
if (handle == nullptr) {
  handle = cached_handles.get(device_id);
  if (handle == nullptr) {
    handle = cached_handles.obtain(device_id);
  }
}
hipblasLtMatmulDesc_t operationDesc = nullptr;
@@ -1064,7 +1090,7 @@ void hipblaslt_gemm(const Tensor *inputA,
int64_t ld_gelumat = (int64_t) ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
hipblasLtComputeType_t gemm_compute_type = HIPBLASLT_COMPUTE_F32;
hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
@@ -1077,7 +1103,7 @@ void hipblaslt_gemm(const Tensor *inputA,
ldb));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIPBLASLT_R_32F));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
@@ -1154,7 +1180,7 @@ void hipblaslt_gemm(const Tensor *inputA,
&epilogue, sizeof(epilogue)));
GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
use_fp8 ? bias_type : (hipblasltDatatype_t)-1,
use_fp8 ? bias_type : (hipDataType)-1,
m, n, k, lda, ldb, ldd, transa, transb, epilogue );
GemmAlgoCache::Algo cached_algo;
if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
@@ -1231,6 +1257,7 @@ void hipblaslt_gemm(const Tensor *inputA,
<< " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
<< tuneLoopCount << " loops " << std::endl;
NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
hipStream_t profilingStream;
NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
using tuning_clock = std::chrono::steady_clock;
@@ -1475,11 +1502,7 @@ void rocblas_gemm(const Tensor *inputA,
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMalloc(&D_temp, sizeof(float)*m*n) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMallocAsync(&D_temp, sizeof(float)*m*n, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}else {
D_temp = D;
@@ -1570,11 +1593,7 @@ void rocblas_gemm(const Tensor *inputA,
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*input_dim) ); // The bias gradient is for the first linear layer
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*input_dim, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}else {
bias_tmp = bias_ptr;
@@ -1600,11 +1619,7 @@ void rocblas_gemm(const Tensor *inputA,
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(bias_tmp) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}
@@ -1652,11 +1667,7 @@ void rocblas_gemm(const Tensor *inputA,
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*output_dim) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*output_dim, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}else {
bias_tmp = bias_ptr;
@@ -1683,11 +1694,7 @@ void rocblas_gemm(const Tensor *inputA,
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(bias_tmp) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}
if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
@@ -1788,11 +1795,7 @@ void rocblas_gemm(const Tensor *inputA,
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(D_temp) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(D_temp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}
}
@@ -1804,7 +1807,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
int ldb, int ldd, bool transa, bool transb, bool grad,
void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0)
const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1)
{
/* If no backend is specified via the env variable, use HIPBLASLT unless it is disabled.
   If the HIPBLASLT backend is enabled and requested, use it regardless of ROCBLAS status
@@ -1845,16 +1848,31 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
#ifdef USE_HIPBLASLT
if (use_hipblaslt)
if (use_hipblaslt || !use_rocblas)
{
// Check that compute_stream_offset is in range.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams);
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad,
workspace, workspaceSize, accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer,
inputCounter, stream);
m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad,
workspace, workspaceSize, accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer,
inputCounter, stream,
handle);
return;
}
#endif
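Note the guard change from if (use_hipblaslt) to if (use_hipblaslt || !use_rocblas): hipBLASLt is now taken whenever rocBLAS was not explicitly requested, not only when hipBLASLt was. A tiny standalone illustration of the new predicate (helper name hypothetical):

#include <cstdio>

// Sketch: effective backend choice after this change; the two flags come
// from the env parsing elided above.
static const char* pick_backend(bool use_hipblaslt, bool use_rocblas) {
  return (use_hipblaslt || !use_rocblas) ? "hipblaslt" : "rocblas";
}

int main() {
  std::printf("%s\n", pick_backend(false, false));  // hipblaslt (new default)
  std::printf("%s\n", pick_backend(false, true));   // rocblas
  std::printf("%s\n", pick_backend(true,  true));   // hipblaslt (request wins)
}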
@@ -42,7 +42,7 @@ extern "C" {
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool transb, bool grad, NVTETensor workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
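Because the new trailing parameters are defaulted, callers that predate this patch compile unchanged; only multi-stream callers need to pass an offset. A hedged wrapper sketch (my_gemm and all tensor handles are hypothetical; they stand in for valid NVTETensor objects created elsewhere):

// Sketch: a pre-existing call site, unchanged by this patch. Defaults leave
// the backend auto-selected and compute_stream_offset at -1 (per-device
// cached hipblasLt handle).
void my_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
             const NVTETensor bias, NVTETensor pre_gelu_out,
             NVTETensor workspace, cudaStream_t stream) {
  nvte_cublas_gemm(A, B, D, bias, pre_gelu_out,
                   /*transa=*/true, /*transb=*/false, /*grad=*/false,
                   workspace, /*accumulate=*/false, /*use_split_accumulator=*/false,
                   /*math_sm_count=*/0, stream);
}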
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
* on multiple streams.