Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
8b27a2b7
"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "3f5b47549567d13db76470073c8f0467c23d4fca"
Commit
8b27a2b7
authored
Apr 23, 2025
by
yuguo
Browse files
[DCU] surpport rocm gemm rocblas
parent
73f3ac47
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
150 additions
and
180 deletions
+150
-180
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+62
-113
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+86
-65
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+2
-2
No files found.
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
8b27a2b7
...
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
...
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
)
{
NVTE_API_CALL
(
nvte_cublas_gemm
);
NVTE_API_CALL
(
nvte_cublas_gemm
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
...
@@ -521,15 +521,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
...
@@ -521,15 +521,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
char
*
NVTE_BLASLT_BLAS
=
std
::
getenv
(
"NVTE_FORCE_BLASLT"
);
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
is_fp8_dtype
(
inputB
->
data
.
dtype
);
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
BLASLT_BLAS
!=
nullptr
&&
NVTE_
BLASLT_BLAS
[
0
]
==
'1'
)){
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_
FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
)){
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#endif //USE_HIPBLASLT
#else
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
@@ -542,12 +538,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
...
@@ -542,12 +538,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#endif //__HIP_PLATFORM_AMD__
#endif //__HIP_PLATFORM_AMD__
grad
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
);
#else
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else
{
}
else
{
hipblas_gemm
(
inputA
,
hipblas_gemm
(
inputA
,
inputB
,
inputB
,
outputD
,
outputD
,
...
@@ -567,7 +565,6 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
...
@@ -567,7 +565,6 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
nullptr
,
nullptr
,
stream
);
stream
);
}
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
#endif //__HIP_PLATFORM_AMD__
}
}
...
@@ -577,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -577,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
cudaStream_t
stream
)
{
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
)
{
NVTE_API_CALL
(
nvte_cublas_atomic_gemm
);
NVTE_API_CALL
(
nvte_cublas_atomic_gemm
);
#ifndef __HIP_PLATFORM_AMD__
#ifndef __HIP_PLATFORM_AMD__
...
@@ -622,15 +619,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -622,15 +619,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
char
*
NVTE_BLASLT_BLAS
=
std
::
getenv
(
"NVTE_FORCE_BLASLT"
);
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
is_fp8_dtype
(
inputB
->
data
.
dtype
);
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
BLASLT_BLAS
!=
nullptr
&&
NVTE_
BLASLT_BLAS
[
0
]
==
'1'
)){
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_
FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_
FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
)){
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#endif //USE_HIPBLASLT
#else
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
@@ -643,14 +636,15 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -643,14 +636,15 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#endif //__HIP_PLATFORM_AMD__
#endif //__HIP_PLATFORM_AMD__
grad
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
);
#else
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
#endif
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else
{
}
hipblas_gemm
(
inputA
,
else
{
hipblas_gemm
(
inputA
,
inputB
,
inputB
,
outputD
,
outputD
,
biasTensor
,
biasTensor
,
...
@@ -669,55 +663,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -669,55 +663,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
inputCounter
,
inputCounter
,
stream
);
stream
);
}
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
#endif //__HIP_PLATFORM_AMD__
}
}
void
nvte_cublaslt_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_cublaslt_gemm
);
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
const
Tensor
*
inputB
=
reinterpret_cast
<
const
Tensor
*>
(
B
);
Tensor
*
outputD
=
reinterpret_cast
<
Tensor
*>
(
D
);
const
Tensor
*
biasTensor
=
reinterpret_cast
<
const
Tensor
*>
(
bias
);
Tensor
*
outputGelu
=
reinterpret_cast
<
Tensor
*>
(
pre_gelu_out
);
Tensor
*
wspace
=
reinterpret_cast
<
Tensor
*>
(
workspace
);
const
int
m
=
transa
?
inputA
->
data
.
shape
[
0
]
:
inputA
->
data
.
shape
[
1
];
const
int
k
=
transa
?
inputA
->
data
.
shape
[
1
]
:
inputA
->
data
.
shape
[
0
];
const
int
n
=
transb
?
inputB
->
data
.
shape
[
1
]
:
inputB
->
data
.
shape
[
0
];
int
lda
,
ldb
,
ldd
;
if
(
transa
&&
!
transb
)
{
// TN
lda
=
k
;
ldb
=
k
;
ldd
=
m
;
}
else
if
(
!
transa
&&
!
transb
)
{
// NN
lda
=
m
;
ldb
=
k
;
ldd
=
m
;
}
else
if
(
!
transa
&&
transb
)
{
// NT
lda
=
m
;
ldb
=
n
;
ldd
=
m
;
}
else
{
// TT
NVTE_ERROR
(
"TT layout not allowed."
);
}
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
#ifdef __HIP_PLATFORM_AMD__
transa
,
transb
,
#else
(
transa
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
(
transb
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
#endif //__HIP_PLATFORM_AMD__
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
}
void
nvte_multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
void
nvte_multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
...
@@ -736,20 +686,19 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
...
@@ -736,20 +686,19 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for
(
int
s
=
0
;
s
<
num_stream_used
;
s
++
)
{
for
(
int
s
=
0
;
s
<
num_stream_used
;
s
++
)
{
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
compute_streams
[
s
],
cublas_event
[
0
]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
compute_streams
[
s
],
cublas_event
[
0
]));
}
}
const
char
*
NVTE_BLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_BLAS_MULSTREAM"
);
const
char
*
NVTE_HIPBLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_HIPBLAS_MULSTREAM"
);
const
char
*
NVTE_BLASLT_BLAS
=
std
::
getenv
(
"NVTE_FORCE_BLASLT"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
bool
NVTE_FORCE_BLASLT_MULSTREAM
;
bool
NVTE_FORCE_HIPBLAS_MULSTREAM
;
if
(
NVTE_HIPBLAS_MULSTREAM
!=
nullptr
&&
NVTE_HIPBLAS_MULSTREAM
[
0
]
==
'1'
){
if
(
NVTE_BLAS_MULSTREAM
==
nullptr
){
NVTE_FORCE_HIPBLAS_MULSTREAM
=
true
;
NVTE_FORCE_BLASLT_MULSTREAM
=
true
;
if
((
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
&&
(
NVTE_HIPBLAS_MULSTREAM
!=
nullptr
&&
NVTE_HIPBLAS_MULSTREAM
[
0
]
==
'1'
))
}
else
if
((
NVTE_BLASLT_BLAS
!=
nullptr
&&
NVTE_BLASLT_BLAS
[
0
]
==
'1'
)
&&
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
)){
NVTE_ERROR
(
"NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time."
);
NVTE_ERROR
(
"NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time."
);
}
else
{
}
else
{
NVTE_FORCE_BLAS
LT
_MULSTREAM
=
false
;
NVTE_FORCE_
HIP
BLAS_MULSTREAM
=
false
;
}
}
if
(
NVTE_FORCE_BLAS
LT
_MULSTREAM
){
if
(
NVTE_FORCE_
HIP
BLAS_MULSTREAM
){
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas
lt
_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
compute_streams
[
i
%
num_streams
]);
compute_streams
[
i
%
num_streams
]);
}
}
...
@@ -757,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
...
@@ -757,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
compute_streams
[
i
%
num_streams
]);
compute_streams
[
i
%
num_streams
]
,
1
,
0
);
}
}
}
}
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
8b27a2b7
...
@@ -36,26 +36,30 @@ namespace {
...
@@ -36,26 +36,30 @@ namespace {
#ifdef USE_HIPBLASLT
#ifdef USE_HIPBLASLT
static
hipDataType
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
#if HIP_VERSION >= 60000000
typedef
hipDataType
hipblasltDatatype_t
;
typedef
hipblasComputeType_t
hipblasLtComputeType_t
;
#define HIPBLASLT_R_16F HIP_R_16F
#define HIPBLASLT_R_32F HIP_R_32F
#define HIPBLASLT_R_16B HIP_R_16BF
#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
#endif // #if HIP_VERSION >= 60000000
hipblasltDatatype_t
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
switch
(
t
)
{
switch
(
t
)
{
case
DType
::
kFloat16
:
case
DType
::
kFloat16
:
return
HIP_R_16F
;
return
HIP
BLASLT
_R_16F
;
case
DType
::
kFloat32
:
case
DType
::
kFloat32
:
return
HIP_R_32F
;
return
HIP
BLASLT
_R_32F
;
case
DType
::
kBFloat16
:
case
DType
::
kBFloat16
:
return
HIP_R_16BF
;
return
HIPBLASLT_R_16B
;
#if HIP_VERSION >= 60300000
case
DType
::
kFloat8E4M3
:
case
DType
::
kFloat8E4M3
:
return
te_fp8_fnuz
()
?
HIP_R_8F_E4M3_FNUZ
:
HIP
_R_8F_E4M3
;
return
HIPBLASLT
_R_8F_E4M3
;
case
DType
::
kFloat8E5M2
:
case
DType
::
kFloat8E5M2
:
return
te_fp8_fnuz
()
?
HIP_R_8F_E5M2_FNUZ
:
HIP_R_8F_E5M2
;
return
HIPBLASLT_R_8F_E5M2
;
#else
case
DType
::
kFloat8E4M3
:
return
HIP_R_8F_E4M3_FNUZ
;
case
DType
::
kFloat8E5M2
:
return
HIP_R_8F_E5M2_FNUZ
;
#endif
default:
default:
NVTE_ERROR
(
"Invalid type"
);
NVTE_ERROR
(
"Invalid type"
);
}
}
...
@@ -363,7 +367,11 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
...
@@ -363,7 +367,11 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
float
))
);
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
float
))
);
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
float
),
stream
)
);
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
float
),
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
hipLaunchKernelGGL
((
bias_gradient_kernel
<
Tin
,
THREADS_PER_BLOCK
>
),
dim3
(
grid
),
dim3
(
block
),
0
,
stream
,
in
,
out
,
m
,
n
);
hipLaunchKernelGGL
((
bias_gradient_kernel
<
Tin
,
THREADS_PER_BLOCK
>
),
dim3
(
grid
),
dim3
(
block
),
0
,
stream
,
in
,
out
,
m
,
n
);
}
}
...
@@ -567,11 +575,11 @@ public:
...
@@ -567,11 +575,11 @@ public:
const
std
::
string_view
&
getName
(
const
T
&
val
)
{
const
std
::
string_view
&
getName
(
const
T
&
val
)
{
return
map
.
at
(
val
);
return
map
.
at
(
val
);
}
}
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
,
std
::
function
<
bool
(
const
T
&
)
>
filter
=
nullptr
)
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
)
{
{
for
(
auto
iter
=
map
.
begin
();
iter
!=
map
.
end
();
++
iter
)
for
(
auto
iter
=
map
.
begin
();
iter
!=
map
.
end
();
++
iter
)
{
{
if
(
(
name
==
iter
->
second
)
&&
(
!
filter
||
filter
(
iter
->
first
)))
return
iter
->
first
;
if
(
name
==
iter
->
second
)
return
iter
->
first
;
}
}
NVTE_ERROR
(
"Invalid "
,
label
,
" name: "
,
name
);
NVTE_ERROR
(
"Invalid "
,
label
,
" name: "
,
name
);
}
}
...
@@ -579,18 +587,14 @@ protected:
...
@@ -579,18 +587,14 @@ protected:
const
std
::
unordered_map
<
T
,
std
::
string_view
>
&
map
;
const
std
::
unordered_map
<
T
,
std
::
string_view
>
&
map
;
};
};
static
std
::
unordered_map
<
hipDataType
,
std
::
string_view
>
type_name_map
=
{
static
std
::
unordered_map
<
hipblasltDatatype_t
,
std
::
string_view
>
type_name_map
=
{
{
HIP_R_32F
,
"float32"
},
{
HIPBLASLT_R_32F
,
"float32"
},
{
HIP_R_16F
,
"float16"
},
{
HIPBLASLT_R_16F
,
"float16"
},
{
HIP_R_16BF
,
"bfloat16"
},
{
HIPBLASLT_R_16B
,
"bfloat16"
},
{
HIP_R_8F_E4M3_FNUZ
,
"float8e4m3"
},
{
HIPBLASLT_R_8F_E4M3
,
"float8e4m3"
},
{
HIP_R_8F_E5M2_FNUZ
,
"float8e5m2"
},
{
HIPBLASLT_R_8F_E5M2
,
"float8e5m2"
},
#if HIP_VERSION >= 60300000
{
HIP_R_8F_E4M3
,
"float8e4m3"
},
{
HIP_R_8F_E5M2
,
"float8e5m2"
},
#endif
};
};
static
NameMapper
<
hipData
T
ype
>
typeNameMapper
(
type_name_map
);
static
NameMapper
<
hip
blaslt
Data
t
ype
_t
>
typeNameMapper
(
type_name_map
);
static
std
::
unordered_map
<
hipblasOperation_t
,
std
::
string_view
>
trans_name_map
=
{
static
std
::
unordered_map
<
hipblasOperation_t
,
std
::
string_view
>
trans_name_map
=
{
{
HIPBLAS_OP_N
,
"N"
},
{
HIPBLAS_OP_N
,
"N"
},
...
@@ -609,24 +613,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
...
@@ -609,24 +613,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
};
};
static
NameMapper
<
hipblasLtEpilogue_t
>
epilogueNameMapper
(
epi_name_map
);
static
NameMapper
<
hipblasLtEpilogue_t
>
epilogueNameMapper
(
epi_name_map
);
static
std
::
unordered_map
<
hipblasComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
static
std
::
unordered_map
<
hipblas
Lt
ComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS_COMPUTE_32
F
,
"f32"
}
{
HIPBLAS
LT
_COMPUTE_
F
32
,
"f32"
}
};
};
static
NameMapper
<
hipblasComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
NameMapper
<
hipblas
Lt
ComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
class
GemmAlgoCache
{
static
class
GemmAlgoCache
{
public:
public:
struct
Key
{
struct
Key
{
int
deviceCap
;
int
deviceCap
;
hipData
T
ype
a_type
,
b_type
,
d_type
,
bias_type
;
hip
blaslt
Data
t
ype
_t
a_type
,
b_type
,
d_type
,
bias_type
;
int
m
,
n
,
k
;
int
m
,
n
,
k
;
int
lda
,
ldb
,
ldd
;
int
lda
,
ldb
,
ldd
;
hipblasOperation_t
transa
,
transb
;
hipblasOperation_t
transa
,
transb
;
hipblasLtEpilogue_t
epilogue
;
hipblasLtEpilogue_t
epilogue
;
Key
(
int
deviceCap_
,
Key
(
int
deviceCap_
,
hipData
T
ype
a_type_
,
hipData
T
ype
b_type_
,
hip
blaslt
Data
t
ype
_t
a_type_
,
hip
blaslt
Data
t
ype
_t
b_type_
,
hipData
T
ype
d_type_
,
hipData
T
ype
bias_type_
,
hip
blaslt
Data
t
ype
_t
d_type_
,
hip
blaslt
Data
t
ype
_t
bias_type_
,
int
m_
,
int
n_
,
int
k_
,
int
lda_
,
int
ldb_
,
int
ldd_
,
int
m_
,
int
n_
,
int
k_
,
int
lda_
,
int
ldb_
,
int
ldd_
,
hipblasOperation_t
transa_
,
hipblasOperation_t
transb_
,
hipblasOperation_t
transa_
,
hipblasOperation_t
transb_
,
hipblasLtEpilogue_t
epilogue_
)
:
hipblasLtEpilogue_t
epilogue_
)
:
...
@@ -861,31 +865,17 @@ protected:
...
@@ -861,31 +865,17 @@ protected:
continue
;
continue
;
}
}
#if HIP_VERSION >= 60300000
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
);
auto
fp8_filter
=
te_fp8_fnuz
()
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
);
?
[](
const
hipDataType
&
val
)
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
);
{
return
(
val
!=
HIP_R_8F_E4M3
&&
val
!=
HIP_R_8F_E5M2
);
}
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipblasltDatatype_t
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
);
:
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3_FNUZ
&&
val
!=
HIP_R_8F_E5M2_FNUZ
);
};
#else
auto
fp8_filter
=
nullptr
;
#endif
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
,
fp8_filter
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
,
fp8_filter
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
,
fp8_filter
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipDataType
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
,
fp8_filter
);
cfg
.
transa
=
transposeNameMapper
.
getValue
(
trans_a
,
"trans_a"
);
cfg
.
transa
=
transposeNameMapper
.
getValue
(
trans_a
,
"trans_a"
);
cfg
.
transb
=
transposeNameMapper
.
getValue
(
trans_b
,
"trans_b"
);
cfg
.
transb
=
transposeNameMapper
.
getValue
(
trans_b
,
"trans_b"
);
cfg
.
epilogue
=
epilogueNameMapper
.
getValue
(
epi
,
"epi"
);
cfg
.
epilogue
=
epilogueNameMapper
.
getValue
(
epi
,
"epi"
);
//Check and filter out compute and scale types
//Check and filter out compute and scale types
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLAS_COMPUTE_32F
||
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLASLT_COMPUTE_F32
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIPBLASLT_R_32F
)
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIP_R_32F
)
{
{
continue
;
continue
;
}
}
...
@@ -968,9 +958,9 @@ protected:
...
@@ -968,9 +958,9 @@ protected:
csv
<<
cfg
.
deviceCap
<<
cfg
.
m
<<
cfg
.
n
<<
cfg
.
k
csv
<<
cfg
.
deviceCap
<<
cfg
.
m
<<
cfg
.
n
<<
cfg
.
k
<<
transposeNameMapper
.
getName
(
cfg
.
transa
)
<<
transposeNameMapper
.
getName
(
cfg
.
transb
)
<<
transposeNameMapper
.
getName
(
cfg
.
transa
)
<<
transposeNameMapper
.
getName
(
cfg
.
transb
)
<<
typeNameMapper
.
getName
(
cfg
.
a_type
)
<<
typeNameMapper
.
getName
(
cfg
.
b_type
)
<<
typeNameMapper
.
getName
(
cfg
.
d_type
)
<<
typeNameMapper
.
getName
(
cfg
.
a_type
)
<<
typeNameMapper
.
getName
(
cfg
.
b_type
)
<<
typeNameMapper
.
getName
(
cfg
.
d_type
)
<<
((
cfg
.
bias_type
==
(
hipData
T
ype
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
((
cfg
.
bias_type
==
(
hip
blaslt
Data
t
ype
_t
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
cfg
.
lda
<<
cfg
.
ldb
<<
cfg
.
ldd
<<
epilogueNameMapper
.
getName
(
cfg
.
epilogue
)
<<
cfg
.
lda
<<
cfg
.
ldb
<<
cfg
.
ldd
<<
epilogueNameMapper
.
getName
(
cfg
.
epilogue
)
<<
computeNameMapper
.
getName
(
HIPBLAS_COMPUTE_32
F
)
<<
typeNameMapper
.
getName
(
HIP_R_32F
)
<<
computeNameMapper
.
getName
(
HIPBLAS
LT
_COMPUTE_
F
32
)
<<
typeNameMapper
.
getName
(
HIP
BLASLT
_R_32F
)
<<
algo
.
ws_size_min
<<
algo
.
ws_size_max
<<
algo
.
algoId
<<
algo
.
index
<<
csv_helper
::
end
()
<<
"
\n
"
;
<<
algo
.
ws_size_min
<<
algo
.
ws_size_max
<<
algo
.
algoId
<<
algo
.
index
<<
csv_helper
::
end
()
<<
"
\n
"
;
}
}
...
@@ -1036,10 +1026,10 @@ void hipblaslt_gemm(const Tensor *inputA,
...
@@ -1036,10 +1026,10 @@ void hipblaslt_gemm(const Tensor *inputA,
const
bool
gelu
=
pre_gelu_out
!=
nullptr
;
const
bool
gelu
=
pre_gelu_out
!=
nullptr
;
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
is_fp8_dtype
(
inputB
->
data
.
dtype
);
const
hipData
T
ype
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hipData
T
ype
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hipData
T
ype
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hipData
T
ype
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
NVTE_CHECK
(
!
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
A_scale_inverse
!=
nullptr
,
NVTE_CHECK
(
!
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
A_scale_inverse
!=
nullptr
,
"FP8 input to GEMM requires inverse of scale!"
);
"FP8 input to GEMM requires inverse of scale!"
);
...
@@ -1073,7 +1063,7 @@ void hipblaslt_gemm(const Tensor *inputA,
...
@@ -1073,7 +1063,7 @@ void hipblaslt_gemm(const Tensor *inputA,
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
// default to tf32 except for e5m2 inputs where the config is not supported
// default to tf32 except for e5m2 inputs where the config is not supported
hipblasComputeType_t
gemm_compute_type
=
HIPBLAS_COMPUTE_32
F
;
hipblas
Lt
ComputeType_t
gemm_compute_type
=
HIPBLAS
LT
_COMPUTE_
F
32
;
// Create matrix descriptors. Not setting any extra attributes.
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Adesc
,
A_type
,
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Adesc
,
A_type
,
...
@@ -1086,7 +1076,7 @@ void hipblaslt_gemm(const Tensor *inputA,
...
@@ -1086,7 +1076,7 @@ void hipblaslt_gemm(const Tensor *inputA,
ldb
));
ldb
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Ddesc
,
D_type
,
m
,
n
,
ldd
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Ddesc
,
D_type
,
m
,
n
,
ldd
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP
BLASLT
_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSA
,
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
transa
)));
&
transa
,
sizeof
(
transa
)));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSB
,
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSB
,
...
@@ -1163,7 +1153,7 @@ void hipblaslt_gemm(const Tensor *inputA,
...
@@ -1163,7 +1153,7 @@ void hipblaslt_gemm(const Tensor *inputA,
&
epilogue
,
sizeof
(
epilogue
)));
&
epilogue
,
sizeof
(
epilogue
)));
GemmAlgoCache
::
Key
gemm_cfg
(
algoCache
.
device_cap
(
device_id
),
A_type
,
B_type
,
D_type
,
GemmAlgoCache
::
Key
gemm_cfg
(
algoCache
.
device_cap
(
device_id
),
A_type
,
B_type
,
D_type
,
use_fp8
?
bias_type
:
(
hipData
T
ype
)
-
1
,
use_fp8
?
bias_type
:
(
hip
blaslt
Data
t
ype
_t
)
-
1
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
epilogue
);
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
epilogue
);
GemmAlgoCache
::
Algo
cached_algo
;
GemmAlgoCache
::
Algo
cached_algo
;
if
(
algoCache
.
find
(
gemm_cfg
,
workspaceSize
,
cached_algo
)
==
0
||
!
cached_algo
.
algo
.
has_value
())
if
(
algoCache
.
find
(
gemm_cfg
,
workspaceSize
,
cached_algo
)
==
0
||
!
cached_algo
.
algo
.
has_value
())
...
@@ -1478,7 +1468,11 @@ void rocblas_gemm(const Tensor *inputA,
...
@@ -1478,7 +1468,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
)
);
NVTE_CHECK_CUDA
(
hipMalloc
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
)
);
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
,
stream
)
);
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
else
{
}
else
{
D_temp
=
D
;
D_temp
=
D
;
...
@@ -1571,7 +1565,11 @@ void rocblas_gemm(const Tensor *inputA,
...
@@ -1571,7 +1565,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
)
);
// The bias gradient is for the first linear layer
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
)
);
// The bias gradient is for the first linear layer
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
,
stream
)
);
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
else
{
}
else
{
bias_tmp
=
bias_ptr
;
bias_tmp
=
bias_ptr
;
...
@@ -1597,7 +1595,11 @@ void rocblas_gemm(const Tensor *inputA,
...
@@ -1597,7 +1595,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
}
...
@@ -1645,7 +1647,11 @@ void rocblas_gemm(const Tensor *inputA,
...
@@ -1645,7 +1647,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
)
);
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
)
);
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
,
stream
)
);
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
else
{
}
else
{
bias_tmp
=
bias_ptr
;
bias_tmp
=
bias_ptr
;
...
@@ -1672,7 +1678,11 @@ void rocblas_gemm(const Tensor *inputA,
...
@@ -1672,7 +1678,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
}
if
(
D_type
==
rocblas_datatype_f16_r
||
D_type
==
rocblas_datatype_bf16_r
)
{
if
(
D_type
==
rocblas_datatype_f16_r
||
D_type
==
rocblas_datatype_bf16_r
)
{
...
@@ -1773,7 +1783,11 @@ void rocblas_gemm(const Tensor *inputA,
...
@@ -1773,7 +1783,11 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
D_temp
)
);
NVTE_CHECK_CUDA
(
hipFree
(
D_temp
)
);
}
else
{
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
D_temp
,
stream
)
);
NVTE_CHECK_CUDA
(
hipFreeAsync
(
D_temp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
}
}
}
...
@@ -1785,15 +1799,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -1785,15 +1799,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
int
ldb
,
int
ldd
,
bool
transa
,
bool
transb
,
bool
grad
,
int
ldb
,
int
ldd
,
bool
transa
,
bool
transb
,
bool
grad
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
)
const
Tensor
*
inputCounter
,
hipStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
)
{
{
/*If no backend is specified with env variable use HIPBLASLT unless it is disabled
/*If no backend is specified with env variable use HIPBLASLT unless it is disabled
If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
Otherwise use ROCBLAS
Otherwise use ROCBLAS
*/
*/
bool
use_hipblaslt
=
std
::
getenv
(
"NVTE_USE_HIPBLASLT"
)
!=
nullptr
;
bool
use_hipblaslt
=
(
std
::
getenv
(
"NVTE_USE_HIPBLASLT"
)
!=
nullptr
)
||
nvte_use_hipblaslt
;
bool
use_rocblas
=
std
::
getenv
(
"NVTE_USE_ROCBLAS"
)
!=
nullptr
;
bool
use_rocblas
=
(
std
::
getenv
(
"NVTE_USE_ROCBLAS"
)
!=
nullptr
)
||
nvte_use_rocblas
;
#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
#error GEMM backend is not specified
...
@@ -1813,12 +1827,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -1813,12 +1827,18 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
if
(
use_hipblaslt
&&
use_rocblas
)
if
(
use_hipblaslt
&&
use_rocblas
)
{
{
use_rocblas
=
false
;
use_rocblas
=
false
;
use_hipblaslt
=
true
;
std
::
cout
<<
"[NOTICE] Two GEMM backend are enabled, hipBLASLt will be used
\n
"
;
std
::
cout
<<
"[NOTICE] Two GEMM backend are enabled, hipBLASLt will be used
\n
"
;
}
else
if
(
!
use_hipblaslt
&&
!
use_rocblas
)
{
use_rocblas
=
false
;
use_hipblaslt
=
true
;
std
::
cout
<<
"[NOTICE] Two GEMM backend are disabled, hipBLASLt will be used
\n
"
;
}
}
#endif
#endif
#ifdef USE_HIPBLASLT
#ifdef USE_HIPBLASLT
if
(
use_hipblaslt
||
!
use_rocblas
)
if
(
use_hipblaslt
)
{
{
hipblaslt_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
hipblaslt_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
@@ -1833,6 +1853,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -1833,6 +1853,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
#endif
#ifdef USE_ROCBLAS
#ifdef USE_ROCBLAS
if
(
use_rocblas
)
{
{
rocblas_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
rocblas_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
...
...
transformer_engine/common/include/transformer_engine/gemm.h
View file @
8b27a2b7
...
@@ -42,7 +42,7 @@ extern "C" {
...
@@ -42,7 +42,7 @@ extern "C" {
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
);
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
*
...
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
cudaStream_t
stream
);
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
);
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
* on multiple streams.
* on multiple streams.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment