Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
8b27a2b7
Commit
8b27a2b7
authored
Apr 23, 2025
by
yuguo
Browse files
[DCU] support rocm gemm rocblas
parent
73f3ac47
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
150 additions
and
180 deletions
+150
-180
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+62
-113
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+86
-65
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+2
-2
No files found.
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
8b27a2b7
...
...
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
)
{
NVTE_API_CALL
(
nvte_cublas_gemm
);
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
...
...
@@ -521,15 +521,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const
char
*
NVTE_BLASLT_BLAS
=
std
::
getenv
(
"NVTE_FORCE_BLASLT"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
BLASLT_BLAS
!=
nullptr
&&
NVTE_
BLASLT_BLAS
[
0
]
==
'1'
)){
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_
FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
)){
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#endif //USE_HIPBLASLT
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
...
@@ -542,32 +538,33 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#endif //__HIP_PLATFORM_AMD__
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
);
#else
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else
{
}
else
{
hipblas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
}
...
...
@@ -577,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
cudaStream_t
stream
)
{
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
)
{
NVTE_API_CALL
(
nvte_cublas_atomic_gemm
);
#ifndef __HIP_PLATFORM_AMD__
...
...
@@ -622,15 +619,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const
char
*
NVTE_BLASLT_BLAS
=
std
::
getenv
(
"NVTE_FORCE_BLASLT"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
BLASLT_BLAS
!=
nullptr
&&
NVTE_
BLASLT_BLAS
[
0
]
==
'1'
)){
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_
FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
)){
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#endif //USE_HIPBLASLT
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
...
@@ -643,81 +636,38 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#endif //__HIP_PLATFORM_AMD__
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
);
#else
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
#endif
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else
{
hipblas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
}
else
{
hipblas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
}
/*! \brief Run a single GEMM by forwarding directly to cublas_gemm.
 *
 *  The opaque NVTETensor handles are reinterpreted as Tensor objects, the
 *  problem sizes (m, n, k) and leading dimensions are derived from the input
 *  shapes and the transpose flags, and the call is handed to cublas_gemm.
 *  The four arguments between math_sm_count and stream are fixed to
 *  0, 0, false and nullptr.
 *
 *  \param[in]     A                      Handle of the A operand.
 *  \param[in]     B                      Handle of the B operand.
 *  \param[in,out] D                      Handle of the output tensor.
 *  \param[in]     bias                   Handle of the bias tensor.
 *  \param[in,out] pre_gelu_out           Handle of the pre-GELU output tensor.
 *  \param[in]     transa                 Whether A is transposed.
 *  \param[in]     transb                 Whether B is transposed.
 *  \param[in]     grad                   Forwarded unchanged to cublas_gemm.
 *  \param[in,out] workspace              Handle of the workspace tensor; its data pointer and
 *                                        first shape entry are forwarded as buffer and size.
 *  \param[in]     accumulate             Forwarded unchanged to cublas_gemm.
 *  \param[in]     use_split_accumulator  Forwarded unchanged to cublas_gemm.
 *  \param[in]     math_sm_count          Forwarded unchanged to cublas_gemm.
 *  \param[in]     stream                 Stream forwarded to cublas_gemm.
 *
 *  Raises an NVTE_ERROR when both transa and transb are set (TT layout).
 */
void nvte_cublaslt_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                        const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb,
                        bool grad, NVTETensor workspace, bool accumulate,
                        bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublaslt_gemm);
  using namespace transformer_engine;

  // Recover the concrete tensor objects behind the opaque handles.
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

  // Problem sizes: which shape entry is used depends on the transpose flags.
  const int m = inputA->data.shape[transa ? 0 : 1];
  const int k = inputA->data.shape[transa ? 1 : 0];
  const int n = inputB->data.shape[transb ? 1 : 0];

  // Only the TN, NN and NT layouts are accepted.
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }

  // Leading dimensions: TN -> (k, k, m), NN -> (m, k, m), NT -> (m, n, m).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;

#ifdef __HIP_PLATFORM_AMD__
  // On the HIP/AMD build the transpose flags are forwarded as plain booleans.
  const bool op_a = transa;
  const bool op_b = transb;
#else
  // Otherwise the flags are translated to cuBLAS operation enums.
  const auto op_a = (transa) ? CUBLAS_OP_T : CUBLAS_OP_N;
  const auto op_b = (transb) ? CUBLAS_OP_T : CUBLAS_OP_N;
#endif  //__HIP_PLATFORM_AMD__

  void *const workspace_ptr = wspace->data.dptr;
  const auto workspace_bytes = wspace->data.shape[0];

  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, op_a, op_b,
              grad, workspace_ptr, workspace_bytes, accumulate, use_split_accumulator,
              math_sm_count, 0, 0, false, nullptr, stream);
}
void
nvte_multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
...
...
@@ -736,20 +686,19 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for
(
int
s
=
0
;
s
<
num_stream_used
;
s
++
)
{
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
compute_streams
[
s
],
cublas_event
[
0
]));
}
const
char
*
NVTE_BLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_BLAS_MULSTREAM"
);
const
char
*
NVTE_BLASLT_BLAS
=
std
::
getenv
(
"NVTE_FORCE_BLASLT"
);
bool
NVTE_FORCE_BLASLT_MULSTREAM
;
if
(
NVTE_BLAS_MULSTREAM
==
nullptr
){
NVTE_FORCE_BLASLT_MULSTREAM
=
true
;
}
else
if
((
NVTE_BLASLT_BLAS
!=
nullptr
&&
NVTE_BLASLT_BLAS
[
0
]
==
'1'
)
&&
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
)){
NVTE_ERROR
(
"NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time."
);
const
char
*
NVTE_HIPBLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_HIPBLAS_MULSTREAM"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
bool
NVTE_FORCE_HIPBLAS_MULSTREAM
;
if
(
NVTE_HIPBLAS_MULSTREAM
!=
nullptr
&&
NVTE_HIPBLAS_MULSTREAM
[
0
]
==
'1'
){
NVTE_FORCE_HIPBLAS_MULSTREAM
=
true
;
if
((
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
&&
(
NVTE_HIPBLAS_MULSTREAM
!=
nullptr
&&
NVTE_HIPBLAS_MULSTREAM
[
0
]
==
'1'
))
NVTE_ERROR
(
"NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time."
);
}
else
{
NVTE_FORCE_BLAS
LT
_MULSTREAM
=
false
;
NVTE_FORCE_
HIP
BLAS_MULSTREAM
=
false
;
}
if
(
NVTE_FORCE_BLAS
LT
_MULSTREAM
){
if
(
NVTE_FORCE_
HIP
BLAS_MULSTREAM
){
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas
lt
_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
compute_streams
[
i
%
num_streams
]);
}
...
...
@@ -757,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
compute_streams
[
i
%
num_streams
]);
compute_streams
[
i
%
num_streams
]
,
1
,
0
);
}
}
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
8b27a2b7
...
...
@@ -36,26 +36,30 @@ namespace {
#ifdef USE_HIPBLASLT
static
hipDataType
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
#if HIP_VERSION >= 60000000
typedef
hipDataType
hipblasltDatatype_t
;
typedef
hipblasComputeType_t
hipblasLtComputeType_t
;
#define HIPBLASLT_R_16F HIP_R_16F
#define HIPBLASLT_R_32F HIP_R_32F
#define HIPBLASLT_R_16B HIP_R_16BF
#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
#endif // #if HIP_VERSION >= 60000000
hipblasltDatatype_t
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
using
namespace
transformer_engine
;
switch
(
t
)
{
case
DType
::
kFloat16
:
return
HIP_R_16F
;
return
HIP
BLASLT
_R_16F
;
case
DType
::
kFloat32
:
return
HIP_R_32F
;
return
HIP
BLASLT
_R_32F
;
case
DType
::
kBFloat16
:
return
HIP_R_16BF
;
#if HIP_VERSION >= 60300000
return
HIPBLASLT_R_16B
;
case
DType
::
kFloat8E4M3
:
return
te_fp8_fnuz
()
?
HIP_R_8F_E4M3_FNUZ
:
HIP
_R_8F_E4M3
;
return
HIPBLASLT
_R_8F_E4M3
;
case
DType
::
kFloat8E5M2
:
return
te_fp8_fnuz
()
?
HIP_R_8F_E5M2_FNUZ
:
HIP_R_8F_E5M2
;
#else
case
DType
::
kFloat8E4M3
:
return
HIP_R_8F_E4M3_FNUZ
;
case
DType
::
kFloat8E5M2
:
return
HIP_R_8F_E5M2_FNUZ
;
#endif
return
HIPBLASLT_R_8F_E5M2
;
default:
NVTE_ERROR
(
"Invalid type"
);
}
...
...
@@ -363,7 +367,11 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
float
))
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
float
),
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
hipLaunchKernelGGL
((
bias_gradient_kernel
<
Tin
,
THREADS_PER_BLOCK
>
),
dim3
(
grid
),
dim3
(
block
),
0
,
stream
,
in
,
out
,
m
,
n
);
}
...
...
@@ -567,11 +575,11 @@ public:
const
std
::
string_view
&
getName
(
const
T
&
val
)
{
return
map
.
at
(
val
);
}
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
,
std
::
function
<
bool
(
const
T
&
)
>
filter
=
nullptr
)
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
)
{
for
(
auto
iter
=
map
.
begin
();
iter
!=
map
.
end
();
++
iter
)
{
if
(
(
name
==
iter
->
second
)
&&
(
!
filter
||
filter
(
iter
->
first
)))
return
iter
->
first
;
if
(
name
==
iter
->
second
)
return
iter
->
first
;
}
NVTE_ERROR
(
"Invalid "
,
label
,
" name: "
,
name
);
}
...
...
@@ -579,18 +587,14 @@ protected:
const
std
::
unordered_map
<
T
,
std
::
string_view
>
&
map
;
};
static
std
::
unordered_map
<
hipDataType
,
std
::
string_view
>
type_name_map
=
{
{
HIP_R_32F
,
"float32"
},
{
HIP_R_16F
,
"float16"
},
{
HIP_R_16BF
,
"bfloat16"
},
{
HIP_R_8F_E4M3_FNUZ
,
"float8e4m3"
},
{
HIP_R_8F_E5M2_FNUZ
,
"float8e5m2"
},
#if HIP_VERSION >= 60300000
{
HIP_R_8F_E4M3
,
"float8e4m3"
},
{
HIP_R_8F_E5M2
,
"float8e5m2"
},
#endif
static
std
::
unordered_map
<
hipblasltDatatype_t
,
std
::
string_view
>
type_name_map
=
{
{
HIPBLASLT_R_32F
,
"float32"
},
{
HIPBLASLT_R_16F
,
"float16"
},
{
HIPBLASLT_R_16B
,
"bfloat16"
},
{
HIPBLASLT_R_8F_E4M3
,
"float8e4m3"
},
{
HIPBLASLT_R_8F_E5M2
,
"float8e5m2"
},
};
static
NameMapper
<
hipData
T
ype
>
typeNameMapper
(
type_name_map
);
static
NameMapper
<
hip
blaslt
Data
t
ype
_t
>
typeNameMapper
(
type_name_map
);
static
std
::
unordered_map
<
hipblasOperation_t
,
std
::
string_view
>
trans_name_map
=
{
{
HIPBLAS_OP_N
,
"N"
},
...
...
@@ -609,24 +613,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
};
static
NameMapper
<
hipblasLtEpilogue_t
>
epilogueNameMapper
(
epi_name_map
);
static
std
::
unordered_map
<
hipblasComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS_COMPUTE_32
F
,
"f32"
}
static
std
::
unordered_map
<
hipblas
Lt
ComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS
LT
_COMPUTE_
F
32
,
"f32"
}
};
static
NameMapper
<
hipblasComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
NameMapper
<
hipblas
Lt
ComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
class
GemmAlgoCache
{
public:
struct
Key
{
int
deviceCap
;
hipData
T
ype
a_type
,
b_type
,
d_type
,
bias_type
;
hip
blaslt
Data
t
ype
_t
a_type
,
b_type
,
d_type
,
bias_type
;
int
m
,
n
,
k
;
int
lda
,
ldb
,
ldd
;
hipblasOperation_t
transa
,
transb
;
hipblasLtEpilogue_t
epilogue
;
Key
(
int
deviceCap_
,
hipData
T
ype
a_type_
,
hipData
T
ype
b_type_
,
hipData
T
ype
d_type_
,
hipData
T
ype
bias_type_
,
hip
blaslt
Data
t
ype
_t
a_type_
,
hip
blaslt
Data
t
ype
_t
b_type_
,
hip
blaslt
Data
t
ype
_t
d_type_
,
hip
blaslt
Data
t
ype
_t
bias_type_
,
int
m_
,
int
n_
,
int
k_
,
int
lda_
,
int
ldb_
,
int
ldd_
,
hipblasOperation_t
transa_
,
hipblasOperation_t
transb_
,
hipblasLtEpilogue_t
epilogue_
)
:
...
...
@@ -860,32 +864,18 @@ protected:
std
::
cout
<<
"[WARNING] Invalid WS size at "
<<
line
<<
"
\n
"
;
continue
;
}
#if HIP_VERSION >= 60300000
auto
fp8_filter
=
te_fp8_fnuz
()
?
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3
&&
val
!=
HIP_R_8F_E5M2
);
}
:
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3_FNUZ
&&
val
!=
HIP_R_8F_E5M2_FNUZ
);
};
#else
auto
fp8_filter
=
nullptr
;
#endif
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
,
fp8_filter
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
,
fp8_filter
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
,
fp8_filter
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipDataType
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
,
fp8_filter
);
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipblasltDatatype_t
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
);
cfg
.
transa
=
transposeNameMapper
.
getValue
(
trans_a
,
"trans_a"
);
cfg
.
transb
=
transposeNameMapper
.
getValue
(
trans_b
,
"trans_b"
);
cfg
.
epilogue
=
epilogueNameMapper
.
getValue
(
epi
,
"epi"
);
//Check and filter out compute and scale types
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLAS_COMPUTE_32F
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIP_R_32F
)
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLASLT_COMPUTE_F32
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIPBLASLT_R_32F
)
{
continue
;
}
...
...
@@ -968,9 +958,9 @@ protected:
csv
<<
cfg
.
deviceCap
<<
cfg
.
m
<<
cfg
.
n
<<
cfg
.
k
<<
transposeNameMapper
.
getName
(
cfg
.
transa
)
<<
transposeNameMapper
.
getName
(
cfg
.
transb
)
<<
typeNameMapper
.
getName
(
cfg
.
a_type
)
<<
typeNameMapper
.
getName
(
cfg
.
b_type
)
<<
typeNameMapper
.
getName
(
cfg
.
d_type
)
<<
((
cfg
.
bias_type
==
(
hipData
T
ype
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
((
cfg
.
bias_type
==
(
hip
blaslt
Data
t
ype
_t
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
cfg
.
lda
<<
cfg
.
ldb
<<
cfg
.
ldd
<<
epilogueNameMapper
.
getName
(
cfg
.
epilogue
)
<<
computeNameMapper
.
getName
(
HIPBLAS_COMPUTE_32
F
)
<<
typeNameMapper
.
getName
(
HIP_R_32F
)
<<
computeNameMapper
.
getName
(
HIPBLAS
LT
_COMPUTE_
F
32
)
<<
typeNameMapper
.
getName
(
HIP
BLASLT
_R_32F
)
<<
algo
.
ws_size_min
<<
algo
.
ws_size_max
<<
algo
.
algoId
<<
algo
.
index
<<
csv_helper
::
end
()
<<
"
\n
"
;
}
...
...
@@ -1036,10 +1026,10 @@ void hipblaslt_gemm(const Tensor *inputA,
const
bool
gelu
=
pre_gelu_out
!=
nullptr
;
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
const
hipData
T
ype
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hipData
T
ype
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hipData
T
ype
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hipData
T
ype
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
NVTE_CHECK
(
!
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
A_scale_inverse
!=
nullptr
,
"FP8 input to GEMM requires inverse of scale!"
);
...
...
@@ -1073,7 +1063,7 @@ void hipblaslt_gemm(const Tensor *inputA,
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
// default to tf32 except for e5m2 inputs where the config is not supported
hipblasComputeType_t
gemm_compute_type
=
HIPBLAS_COMPUTE_32
F
;
hipblas
Lt
ComputeType_t
gemm_compute_type
=
HIPBLAS
LT
_COMPUTE_
F
32
;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Adesc
,
A_type
,
...
...
@@ -1086,7 +1076,7 @@ void hipblaslt_gemm(const Tensor *inputA,
ldb
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Ddesc
,
D_type
,
m
,
n
,
ldd
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP
BLASLT
_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
transa
)));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSB
,
...
...
@@ -1163,7 +1153,7 @@ void hipblaslt_gemm(const Tensor *inputA,
&
epilogue
,
sizeof
(
epilogue
)));
GemmAlgoCache
::
Key
gemm_cfg
(
algoCache
.
device_cap
(
device_id
),
A_type
,
B_type
,
D_type
,
use_fp8
?
bias_type
:
(
hipData
T
ype
)
-
1
,
use_fp8
?
bias_type
:
(
hip
blaslt
Data
t
ype
_t
)
-
1
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
epilogue
);
GemmAlgoCache
::
Algo
cached_algo
;
if
(
algoCache
.
find
(
gemm_cfg
,
workspaceSize
,
cached_algo
)
==
0
||
!
cached_algo
.
algo
.
has_value
())
...
...
@@ -1478,7 +1468,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
D_temp
=
D
;
...
...
@@ -1571,7 +1565,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
)
);
// The bias gradient is for the first linear layer
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
bias_tmp
=
bias_ptr
;
...
...
@@ -1597,7 +1595,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
...
...
@@ -1645,7 +1647,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
bias_tmp
=
bias_ptr
;
...
...
@@ -1672,7 +1678,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
if
(
D_type
==
rocblas_datatype_f16_r
||
D_type
==
rocblas_datatype_bf16_r
)
{
...
...
@@ -1773,7 +1783,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
D_temp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
D_temp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
...
...
@@ -1785,15 +1799,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
int
ldb
,
int
ldd
,
bool
transa
,
bool
transb
,
bool
grad
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
)
const
Tensor
*
inputCounter
,
hipStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
)
{
/*If no backend is specified with env variable use HIPBLASLT unless it is disabled
If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
Otherwise use ROCBLAS
*/
bool
use_hipblaslt
=
std
::
getenv
(
"NVTE_USE_HIPBLASLT"
)
!=
nullptr
;
bool
use_rocblas
=
std
::
getenv
(
"NVTE_USE_ROCBLAS"
)
!=
nullptr
;
bool
use_hipblaslt
=
(
std
::
getenv
(
"NVTE_USE_HIPBLASLT"
)
!=
nullptr
)
||
nvte_use_hipblaslt
;
bool
use_rocblas
=
(
std
::
getenv
(
"NVTE_USE_ROCBLAS"
)
!=
nullptr
)
||
nvte_use_rocblas
;
#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
...
...
@@ -1813,12 +1827,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
if
(
use_hipblaslt
&&
use_rocblas
)
{
use_rocblas
=
false
;
use_hipblaslt
=
true
;
std
::
cout
<<
"[NOTICE] Two GEMM backend are enabled, hipBLASLt will be used
\n
"
;
}
else
if
(
!
use_hipblaslt
&&
!
use_rocblas
)
{
use_rocblas
=
false
;
use_hipblaslt
=
true
;
std
::
cout
<<
"[NOTICE] Two GEMM backend are disabled, hipBLASLt will be used
\n
"
;
}
#endif
#ifdef USE_HIPBLASLT
if
(
use_hipblaslt
||
!
use_rocblas
)
if
(
use_hipblaslt
)
{
hipblaslt_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
...
@@ -1833,6 +1853,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
#ifdef USE_ROCBLAS
if
(
use_rocblas
)
{
rocblas_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
...
transformer_engine/common/include/transformer_engine/gemm.h
View file @
8b27a2b7
...
...
@@ -42,7 +42,7 @@ extern "C" {
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
);
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
...
...
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
cudaStream_t
stream
);
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
);
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
* on multiple streams.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment