OpenDAS / TransformerEngine

Commit 229be5e8
authored May 06, 2025 by yuguo

[DCU] new rocm gemm

parent 388ac735
Showing 4 changed files with 121 additions and 103 deletions:

  +6   -6    transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
  +6   -6    transformer_engine/common/gemm/cublaslt_gemm.cu
  +107 -89   transformer_engine/common/gemm/rocm_gemm.cu
  +2   -2    transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

```diff
@@ -451,7 +451,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                    pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                   use_split_accumulator, _math_sms, _stream_compute[0]);
+                   use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
 
   for (int i = 1; i < _num_splits; i++) {
     input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
@@ -462,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0);
+                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
 
     NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
@@ -510,7 +510,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0);
+                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
 
     NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
@@ -821,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
     nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                      aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                      use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0);
+                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
 
     if (i < num_steps - 1) {
       // P2P communication
@@ -865,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
     nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                      aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                      use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0);
+                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
 
     if (i < _tp_size - 1) {
       // P2P communication
@@ -1010,7 +1010,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
     nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                     use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0);
+                     use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
 
     if (i > 0) {
       // P2P communication chunk
```
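Every call site in this file follows the same pattern: the chunked GEMM was already enqueued on `_stream_compute[i % _stream_compute.size()]`, and the commit now forwards that same index as the new trailing `compute_stream_offset` argument (after the existing `nvte_use_hipblaslt = 1, nvte_use_rocblas = 0` flags), so the ROCm backend can bind a dedicated hipBLASLt handle to each compute stream. A minimal standalone sketch of that indexing invariant (hypothetical example, not part of the commit):

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int num_splits = 8;
  const std::vector<int> streams = {0, 1, 2, 3};  // stands in for _stream_compute
  for (int i = 0; i < num_splits; i++) {
    // The chunk's GEMM runs on stream `slot`, and the same `slot` is passed as
    // compute_stream_offset so the backend picks the handle owned by that stream.
    const int slot = i % static_cast<int>(streams.size());
    std::printf("chunk %d -> stream %d, handle slot %d\n", i, slot, slot);
  }
  return 0;
}
```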
transformer_engine/common/gemm/cublaslt_gemm.cu

```diff
@@ -163,7 +163,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  int ldb, int ldd, bool transa, bool transb, bool grad,
                  void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                  int math_sm_count, int m_split, int n_split, bool gemm_producer,
-                 const Tensor *inputCounter, hipStream_t stream);
+                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
 #else  // Use cublasLt
 using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
 void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                       NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
+                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
   NVTE_API_CALL(nvte_cublas_gemm);
   using namespace transformer_engine;
   const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
@@ -539,7 +539,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
               grad,
               wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
 #ifdef __HIP_PLATFORM_AMD__
-              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas);
+              math_sm_count, 0, 0, false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
 #else
               math_sm_count, 0, 0, false, nullptr, stream);
 #endif
@@ -574,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas) {
+                             cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
   NVTE_API_CALL(nvte_cublas_atomic_gemm);
 #ifndef __HIP_PLATFORM_AMD__
@@ -637,7 +637,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
               grad,
               wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
 #ifdef __HIP_PLATFORM_AMD__
-              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas);
+              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
 #else
               math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
 #endif
@@ -706,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
   for (int i = 0; i < num_gemms; i++) {
     nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                      workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                     compute_streams[i % num_streams], 1, 0);
+                     compute_streams[i % num_streams], 1, 0, i % num_streams);
   }
 }
```
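With the plumbing above, `nvte_multi_stream_cublas_gemm` now tags each per-stream GEMM with its stream index. A hedged usage sketch of the updated entry point on a ROCm build (the wrapper function is hypothetical; tensor construction is elided):

```cpp
#include <transformer_engine/gemm.h>

// Launch one GEMM on a chosen compute stream, binding the matching per-stream
// hipBLASLt handle via the new trailing argument (ROCm build assumed).
void gemm_on_stream(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor bias,
                    NVTETensor pre_gelu, NVTETensor workspace,
                    cudaStream_t stream, int stream_idx) {
  nvte_cublas_gemm(A, B, D, bias, pre_gelu,
                   /*transa=*/true, /*transb=*/false, /*grad=*/false,
                   workspace, /*accumulate=*/false, /*use_split_accumulator=*/false,
                   /*math_sm_count=*/0, stream,
                   /*nvte_use_hipblaslt=*/1, /*nvte_use_rocblas=*/0,
                   /*compute_stream_offset=*/stream_idx);  // new in this commit
}
```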
transformer_engine/common/gemm/rocm_gemm.cu

```diff
@@ -37,30 +37,26 @@ namespace {
 #ifdef USE_HIPBLASLT
-#if HIP_VERSION >= 60000000
-typedef hipDataType hipblasltDatatype_t;
-typedef hipblasComputeType_t hipblasLtComputeType_t;
-#define HIPBLASLT_R_16F HIP_R_16F
-#define HIPBLASLT_R_32F HIP_R_32F
-#define HIPBLASLT_R_16B HIP_R_16BF
-#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
-#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
-#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
-#endif // #if HIP_VERSION >= 60000000
-hipblasltDatatype_t get_hipblaslt_dtype(const transformer_engine::DType t) {
+static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
   using namespace transformer_engine;
   switch (t) {
     case DType::kFloat16:
-      return HIPBLASLT_R_16F;
+      return HIP_R_16F;
     case DType::kFloat32:
-      return HIPBLASLT_R_32F;
+      return HIP_R_32F;
     case DType::kBFloat16:
-      return HIPBLASLT_R_16B;
+      return HIP_R_16BF;
+#if HIP_VERSION >= 60300000
     case DType::kFloat8E4M3:
-      return HIPBLASLT_R_8F_E4M3;
+      return te_fp8_fnuz() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
     case DType::kFloat8E5M2:
-      return HIPBLASLT_R_8F_E5M2;
+      return te_fp8_fnuz() ? HIP_R_8F_E5M2_FNUZ : HIP_R_8F_E5M2;
+#else
+    case DType::kFloat8E4M3:
+      return HIP_R_8F_E4M3_FNUZ;
+    case DType::kFloat8E5M2:
+      return HIP_R_8F_E5M2_FNUZ;
+#endif
     default:
       NVTE_ERROR("Invalid type");
   }
```
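The rewrite drops the old `HIPBLASLT_*` compatibility typedefs and macros in favor of the plain `hipDataType` enum, and on ROCm 6.3+ chooses between the FNUZ and OCP FP8 encodings at runtime. An isolated sketch of that selection (the `te_fp8_fnuz_stub()` below is a stand-in for `te_fp8_fnuz()`, which in TransformerEngine reports whether the device uses the FNUZ FP8 format):

```cpp
#include <hip/hip_runtime.h>
#include <hip/library_types.h>  // hipDataType

static bool te_fp8_fnuz_stub() { return true; }  // placeholder for te_fp8_fnuz()

hipDataType pick_fp8_e4m3() {
#if HIP_VERSION >= 60300000
  // Both encodings exist from ROCm 6.3 on; pick the one the device supports.
  return te_fp8_fnuz_stub() ? HIP_R_8F_E4M3_FNUZ : HIP_R_8F_E4M3;
#else
  return HIP_R_8F_E4M3_FNUZ;  // only the FNUZ encoding is available earlier
#endif
}
```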
```diff
@@ -368,11 +364,7 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
   if( !stream_order_alloc ){
     NVTE_CHECK_CUDA( hipMemset(out, 0, n * sizeof(float)) );
   } else {
-#if HIP_VERSION >= 50300000
     NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n * sizeof(float), stream) );
-#else
-    NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
   }
   hipLaunchKernelGGL((bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                      stream, in, out, m, n);
```
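Here and in several `rocblas_gemm` hunks later in this file, the `#if HIP_VERSION >= 50300000` guards around the stream-ordered calls are removed, so the code now assumes `hipMemsetAsync`/`hipMallocAsync`/`hipFreeAsync` are always present (they have been since ROCm 5.3). A minimal round trip through that API for reference (illustrative standalone program; error checks elided):

```cpp
#include <hip/hip_runtime.h>

int main() {
  hipStream_t stream;
  (void)hipStreamCreate(&stream);
  float *buf = nullptr;
  // Stream-ordered allocate, clear, and free: all three are ordered on `stream`.
  (void)hipMallocAsync(reinterpret_cast<void **>(&buf), 1024 * sizeof(float), stream);
  (void)hipMemsetAsync(buf, 0, 1024 * sizeof(float), stream);
  (void)hipFreeAsync(buf, stream);
  (void)hipStreamSynchronize(stream);
  (void)hipStreamDestroy(stream);
  return 0;
}
```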
```diff
@@ -576,11 +568,11 @@ public:
   const std::string_view& getName(const T& val) {
     return map.at(val);
   }
 
-  T getValue(const std::string& name, const char *label = "")
+  T getValue(const std::string& name, const char *label = "", std::function<bool(const T&)> filter = nullptr)
   {
     for (auto iter = map.begin(); iter != map.end(); ++iter)
     {
-      if (name == iter->second) return iter->first;
+      if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
     }
     NVTE_ERROR("Invalid ", label, " name: ", name);
   }
```
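`getValue()` performs a reverse lookup from a name to an enum value; the new optional `filter` lets the caller skip entries whose name matches but whose value is not acceptable, which matters once `type_name_map` (next hunk) holds both FNUZ and OCP FP8 entries under the same strings. A self-contained sketch of the same pattern (illustrative types, not the real maps):

```cpp
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

template <typename T>
T lookup_by_name(const std::unordered_map<T, std::string> &map, const std::string &name,
                 std::function<bool(const T &)> filter = nullptr) {
  for (const auto &[value, entry_name] : map) {
    // When two values share a name (e.g. "float8e4m3" for FNUZ and OCP),
    // the filter decides which duplicate wins.
    if (entry_name == name && (!filter || filter(value))) return value;
  }
  throw std::runtime_error("Invalid name: " + name);
}
```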
```diff
@@ -588,14 +580,18 @@ protected:
   const std::unordered_map<T, std::string_view>& map;
 };
 
-static std::unordered_map<hipblasltDatatype_t, std::string_view> type_name_map = {
-  {HIPBLASLT_R_32F, "float32"},
-  {HIPBLASLT_R_16F, "float16"},
-  {HIPBLASLT_R_16B, "bfloat16"},
-  {HIPBLASLT_R_8F_E4M3, "float8e4m3"},
-  {HIPBLASLT_R_8F_E5M2, "float8e5m2"},
+static std::unordered_map<hipDataType, std::string_view> type_name_map = {
+  {HIP_R_32F, "float32"},
+  {HIP_R_16F, "float16"},
+  {HIP_R_16BF, "bfloat16"},
+  {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
+  {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
+#if HIP_VERSION >= 60300000
+  {HIP_R_8F_E4M3, "float8e4m3"},
+  {HIP_R_8F_E5M2, "float8e5m2"},
+#endif
 };
-static NameMapper<hipblasltDatatype_t> typeNameMapper(type_name_map);
+static NameMapper<hipDataType> typeNameMapper(type_name_map);
 
 static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
   {HIPBLAS_OP_N, "N"},
```
```diff
@@ -614,24 +610,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
 };
 static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
 
-static std::unordered_map<hipblasLtComputeType_t, std::string_view> comp_name_map = {
-  {HIPBLASLT_COMPUTE_F32, "f32"}
+static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
+  {HIPBLAS_COMPUTE_32F, "f32"}
 };
-static NameMapper<hipblasLtComputeType_t> computeNameMapper(comp_name_map);
+static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
 
 static class GemmAlgoCache {
 public:
   struct Key {
     int deviceCap;
-    hipblasltDatatype_t a_type, b_type, d_type, bias_type;
+    hipDataType a_type, b_type, d_type, bias_type;
     int m, n, k;
     int lda, ldb, ldd;
     hipblasOperation_t transa, transb;
     hipblasLtEpilogue_t epilogue;
 
     Key(int deviceCap_,
-        hipblasltDatatype_t a_type_, hipblasltDatatype_t b_type_,
-        hipblasltDatatype_t d_type_, hipblasltDatatype_t bias_type_,
+        hipDataType a_type_, hipDataType b_type_,
+        hipDataType d_type_, hipDataType bias_type_,
         int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
         hipblasOperation_t transa_, hipblasOperation_t transb_,
         hipblasLtEpilogue_t epilogue_) :
```
```diff
@@ -866,17 +862,31 @@ protected:
         continue;
       }
-      cfg.a_type = typeNameMapper.getValue(type_a, "type_a");
-      cfg.b_type = typeNameMapper.getValue(type_b, "type_b");
-      cfg.d_type = typeNameMapper.getValue(type_d, "type_d");
-      cfg.bias_type = (bias_type == "-") ? (hipblasltDatatype_t)-1
-                                         : typeNameMapper.getValue(bias_type, "bias_type");
+#if HIP_VERSION >= 60300000
+      auto fp8_filter = te_fp8_fnuz()
+          ? [](const hipDataType& val) { return (val != HIP_R_8F_E4M3 && val != HIP_R_8F_E5M2); }
+          : [](const hipDataType& val) { return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ); };
+#else
+      auto fp8_filter = nullptr;
+#endif
+      cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
+      cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
+      cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
+      cfg.bias_type = (bias_type == "-") ? (hipDataType)-1
+                                         : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
       cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
       cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
       cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
       //Check and filter out compute and scale types
-      if (computeNameMapper.getValue(comp, "comp") != HIPBLASLT_COMPUTE_F32 ||
-          typeNameMapper.getValue(scale, "scale") != HIPBLASLT_R_32F)
+      if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
+          typeNameMapper.getValue(scale, "scale") != HIP_R_32F)
       {
         continue;
       }
```
```diff
@@ -959,9 +969,9 @@ protected:
         csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k
             << transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
             << typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
-            << ((cfg.bias_type == (hipblasltDatatype_t)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
+            << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
             << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
-            << computeNameMapper.getName(HIPBLASLT_COMPUTE_F32) << typeNameMapper.getName(HIPBLASLT_R_32F)
+            << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
             << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index
             << csv_helper::end() << "\n";
       }
```
```diff
@@ -995,6 +1005,19 @@ static inline int getIntEnv(const char *name, int defval, int minval)
 }
 //namespace
 
+/* Warning: only call once per device!
+ * When calling nvte_multi_stream_cublas_gemm with hipblaslt backend
+ * need to create multiple handles corresponding to compute_streams
+ * to avoid a handle be used by multi-streams concurrently.
+ */
+static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
+  NVTE_CHECK(hipblaslt_handles != nullptr);
+  for (int i = 0; i < num_streams; i++) {
+    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
+  }
+}
+
 void hipblaslt_gemm(const Tensor *inputA,
                     const Tensor *inputB,
                     Tensor *outputD,
```
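The helper above exists because a single hipBLASLt handle must not be used from several streams at once; `cublas_gemm` (near the end of this file) invokes it through `std::call_once` so the per-stream handle array is created exactly once per process. A standalone sketch of that initialization pattern (`kNumStreams` and the bare `hipblasLtCreate` stand in for `num_streams` and the `NVTE_CHECK_HIPBLASLT`-wrapped call):

```cpp
#include <hipblaslt/hipblaslt.h>
#include <mutex>

constexpr int kNumStreams = 4;  // stand-in for num_streams

static void init_handles(hipblasLtHandle_t *handles) {
  for (int i = 0; i < kNumStreams; i++) (void)hipblasLtCreate(&handles[i]);
}

hipblasLtHandle_t handle_for_stream(int offset) {
  static std::once_flag flag;
  static hipblasLtHandle_t handles[kNumStreams];
  std::call_once(flag, init_handles, handles);  // thread-safe one-time init
  return handles[offset];
}
```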
```diff
@@ -1014,7 +1037,8 @@ void hipblaslt_gemm(const Tensor *inputA,
                     int n_split,
                     bool gemm_producer,
                     const Tensor *inputCounter,
-                    hipStream_t stream) {
+                    hipStream_t stream,
+                    hipblasLtHandle_t handle) {
   void *A = inputA->data.dptr;
   void *A_scale_inverse = inputA->scale_inv.dptr;
```
```diff
@@ -1027,10 +1051,10 @@ void hipblaslt_gemm(const Tensor *inputA,
   const bool gelu = pre_gelu_out != nullptr;
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
-  const hipblasltDatatype_t A_type = get_hipblaslt_dtype(inputA->data.dtype);
-  const hipblasltDatatype_t B_type = get_hipblaslt_dtype(inputB->data.dtype);
-  const hipblasltDatatype_t D_type = get_hipblaslt_dtype(outputD->data.dtype);
-  const hipblasltDatatype_t bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
+  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
+  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
+  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
+  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
 
   NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
              "FP8 input to GEMM requires inverse of scale!");
```
```diff
@@ -1050,11 +1074,13 @@ void hipblaslt_gemm(const Tensor *inputA,
   int device_id;
   NVTE_CHECK_CUDA(hipGetDevice(&device_id));
-  hipblasLtHandle_t handle = cached_handles.get(device_id);
   if (handle == nullptr)
   {
-    handle = cached_handles.obtain(device_id);
+    handle = cached_handles.get(device_id);
+    if (handle == nullptr)
+    {
+      handle = cached_handles.obtain(device_id);
+    }
   }
 
   hipblasLtMatmulDesc_t operationDesc = nullptr;
   hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
```
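A caller-supplied handle (the new parameter) now takes precedence; only when none is given does `hipblaslt_gemm` fall back to the per-device cached handle, preserving the old behavior for `compute_stream_offset == -1`. A sketch of that resolution order (the cache below is illustrative, not the real `cached_handles`):

```cpp
#include <hipblaslt/hipblaslt.h>
#include <unordered_map>

static std::unordered_map<int, hipblasLtHandle_t> g_handle_cache;  // device_id -> handle

hipblasLtHandle_t resolve_handle(hipblasLtHandle_t caller_handle, int device_id) {
  if (caller_handle != nullptr) return caller_handle;  // per-stream handle wins
  auto it = g_handle_cache.find(device_id);
  if (it != g_handle_cache.end()) return it->second;   // cached per-device handle
  hipblasLtHandle_t handle;
  (void)hipblasLtCreate(&handle);                      // create and cache on first use
  g_handle_cache.emplace(device_id, handle);
  return handle;
}
```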
```diff
@@ -1064,7 +1090,7 @@ void hipblaslt_gemm(const Tensor *inputA,
   int64_t ld_gelumat = (int64_t) ldd;
 
   // default to tf32 except for e5m2 inputs where the config is not supported
-  hipblasLtComputeType_t gemm_compute_type = HIPBLASLT_COMPUTE_F32;
+  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
 
   // Create matrix descriptors. Not setting any extra attributes.
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
```
```diff
@@ -1077,7 +1103,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                    ldb));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
 
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIPBLASLT_R_32F));
+  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                        &transa, sizeof(transa)));
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
```
```diff
@@ -1154,7 +1180,7 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                        &epilogue, sizeof(epilogue)));
 
   GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
-                              use_fp8 ? bias_type : (hipblasltDatatype_t)-1,
+                              use_fp8 ? bias_type : (hipDataType)-1,
                               m, n, k, lda, ldb, ldd, transa, transb, epilogue);
   GemmAlgoCache::Algo cached_algo;
   if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
```
```diff
@@ -1231,6 +1257,7 @@ void hipblaslt_gemm(const Tensor *inputA,
               << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
               << tuneLoopCount << " loops " << std::endl;
 
+    NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
     hipStream_t profilingStream;
     NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
     using tuning_clock = std::chrono::steady_clock;
```
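The added `hipStreamSynchronize` drains any work still pending on the caller's stream before algorithm tuning begins on a separate non-blocking stream, so earlier kernels cannot overlap with and skew the timing runs. The isolation pattern in outline (sketch only; the real tuning loop lives in the surrounding code):

```cpp
#include <hip/hip_runtime.h>

void profile_in_isolation(hipStream_t caller_stream) {
  (void)hipStreamSynchronize(caller_stream);  // the line this commit adds: flush pending work
  hipStream_t profiling_stream;
  (void)hipStreamCreateWithFlags(&profiling_stream, hipStreamNonBlocking);
  // ... time candidate GEMM algorithms on profiling_stream ...
  (void)hipStreamSynchronize(profiling_stream);
  (void)hipStreamDestroy(profiling_stream);
}
```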
```diff
@@ -1475,11 +1502,7 @@ void rocblas_gemm(const Tensor *inputA,
     if( !stream_order_alloc ){
       NVTE_CHECK_CUDA( hipMalloc(&D_temp, sizeof(float) * m * n) );
     } else {
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipMallocAsync(&D_temp, sizeof(float) * m * n, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   } else {
     D_temp = D;
```
```diff
@@ -1570,11 +1593,7 @@ void rocblas_gemm(const Tensor *inputA,
     if( !stream_order_alloc ){
      NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float) * input_dim) ); // The bias gradient is for the first linear layer
    } else {
-#if HIP_VERSION >= 50300000
      NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream) );
-#else
-     NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
    }
  } else {
    bias_tmp = bias_ptr;
```
```diff
@@ -1600,11 +1619,7 @@ void rocblas_gemm(const Tensor *inputA,
     if( !stream_order_alloc ){
       NVTE_CHECK_CUDA( hipFree(bias_tmp) );
     } else {
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
```
```diff
@@ -1652,11 +1667,7 @@ void rocblas_gemm(const Tensor *inputA,
     if( !stream_order_alloc ){
       NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float) * output_dim) );
     } else {
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float) * output_dim, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   } else {
     bias_tmp = bias_ptr;
```
```diff
@@ -1683,11 +1694,7 @@ void rocblas_gemm(const Tensor *inputA,
     if( !stream_order_alloc ){
       NVTE_CHECK_CUDA( hipFree(bias_tmp) );
     } else {
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
   if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
```
```diff
@@ -1788,11 +1795,7 @@ void rocblas_gemm(const Tensor *inputA,
     if( !stream_order_alloc ){
       NVTE_CHECK_CUDA( hipFree(D_temp) );
     } else {
-#if HIP_VERSION >= 50300000
       NVTE_CHECK_CUDA( hipFreeAsync(D_temp, stream) );
-#else
-      NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
-#endif
     }
   }
 }
```
```diff
@@ -1804,7 +1807,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  int ldb, int ldd, bool transa, bool transb, bool grad,
                  void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                  int math_sm_count, int m_split, int n_split, bool gemm_producer,
-                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0)
+                 const Tensor *inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1)
 {
   /*If no backend is specified with env variable use HIPBLASLT unless it is disabled
     If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
```
```diff
@@ -1845,8 +1848,21 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif
 #ifdef USE_HIPBLASLT
-  if (use_hipblaslt)
+  if (use_hipblaslt || !use_rocblas)
   {
+    // Check compute_stream_offset valid.
+    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams);
+    hipblasLtHandle_t handle = nullptr;
+    if (compute_stream_offset != -1)
+    {
+      // Init hipblaslt handles (once, globally)
+      static std::once_flag init_flag;
+      static hipblasLtHandle_t hipblaslt_handles[num_streams];
+      std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
+      handle = hipblaslt_handles[compute_stream_offset];
+    }
     hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
                    m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
```
```diff
@@ -1854,7 +1870,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                    grad,
                    workspace, workspaceSize, accumulate, use_split_accumulator,
                    math_sm_count, m_split, n_split, gemm_producer,
-                   inputCounter, stream);
+                   inputCounter, stream,
+                   handle);
     return;
   }
 #endif
```
transformer_engine/common/include/transformer_engine/gemm.h

```diff
@@ -42,7 +42,7 @@ extern "C" {
 void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                       NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
+                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
 
 /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
  *
```
```diff
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0);
+                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
 
 /*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
  * on multiple streams.
```
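Because the new trailing parameter defaults to `-1`, existing callers of the public header keep compiling and keep the old behavior (cached per-device handle). A hedged compatibility sketch (hypothetical wrapper; tensor arguments assumed already constructed):

```cpp
#include <transformer_engine/gemm.h>

void legacy_call(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor bias,
                 NVTETensor pre_gelu, NVTETensor workspace, cudaStream_t stream) {
  // Omitting the final argument is equivalent to compute_stream_offset = -1,
  // i.e. the pre-commit code path.
  nvte_cublas_gemm(A, B, D, bias, pre_gelu, /*transa=*/true, /*transb=*/false,
                   /*grad=*/false, workspace, /*accumulate=*/false,
                   /*use_split_accumulator=*/false, /*math_sm_count=*/0, stream);
}
```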