Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
229be5e8
Commit
229be5e8
authored
May 06, 2025
by
yuguo
Browse files
[DCU] new rocm gemm
parent
388ac735
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
121 additions
and
103 deletions
+121
-103
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+6
-6
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+6
-6
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+107
-89
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+2
-2
No files found.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
229be5e8
...
...
@@ -451,7 +451,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
]);
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
]
,
1
,
0
,
0
);
for
(
int
i
=
1
;
i
<
_num_splits
;
i
++
)
{
input_a_chunk
=
get_tensor_chunk
(
A
,
i
*
input_a_chunk_size
,
{
m_chunk
,
k
});
...
...
@@ -462,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
);
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
()
);
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_comm
,
_stream_compute
[(
i
-
1
)
%
_stream_compute
.
size
()]));
...
...
@@ -510,7 +510,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
);
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
()
);
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_comm
,
_stream_compute
[
i
%
_stream_compute
.
size
()]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_comm
,
_start_comm
,
0
));
...
...
@@ -821,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
);
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
()
);
if
(
i
<
num_steps
-
1
)
{
// P2P communication
...
...
@@ -865,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
);
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
()
);
if
(
i
<
_tp_size
-
1
)
{
// P2P communication
...
...
@@ -1010,7 +1010,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
],
1
,
0
);
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
],
1
,
0
,
stream_id
);
if
(
i
>
0
)
{
// P2P communication chunk
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
229be5e8
...
...
@@ -163,7 +163,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
int
ldb
,
int
ldd
,
bool
transa
,
bool
transb
,
bool
grad
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
);
const
Tensor
*
inputCounter
,
hipStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
);
#else // Use cublasLt
using
cublasHandleManager
=
detail
::
HandleManager
<
cublasLtHandle_t
,
CreateCublasHandle
>
;
void
cublas_gemm
(
const
Tensor
*
inputA
,
const
Tensor
*
inputB
,
Tensor
*
outputD
,
...
...
@@ -484,7 +484,7 @@ static void init_streams_and_events_batchgemm() {
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
)
{
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
NVTE_API_CALL
(
nvte_cublas_gemm
);
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
...
...
@@ -539,7 +539,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
);
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
,
compute_stream_offset
);
#else
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
...
...
@@ -574,7 +574,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
)
{
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
NVTE_API_CALL
(
nvte_cublas_atomic_gemm
);
#ifndef __HIP_PLATFORM_AMD__
...
...
@@ -637,7 +637,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
#ifdef __HIP_PLATFORM_AMD__
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
);
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
,
compute_stream_offset
);
#else
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
#endif
...
...
@@ -706,7 +706,7 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
compute_streams
[
i
%
num_streams
],
1
,
0
);
compute_streams
[
i
%
num_streams
],
1
,
0
,
i
%
num_streams
);
}
}
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
229be5e8
...
...
@@ -37,30 +37,26 @@ namespace {
#ifdef USE_HIPBLASLT
#if HIP_VERSION >= 60000000
typedef
hipDataType
hipblasltDatatype_t
;
typedef
hipblasComputeType_t
hipblasLtComputeType_t
;
#define HIPBLASLT_R_16F HIP_R_16F
#define HIPBLASLT_R_32F HIP_R_32F
#define HIPBLASLT_R_16B HIP_R_16BF
#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
#endif // #if HIP_VERSION >= 60000000
hipblasltDatatype_t
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
static
hipDataType
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
using
namespace
transformer_engine
;
switch
(
t
)
{
case
DType
::
kFloat16
:
return
HIP
BLASLT
_R_16F
;
return
HIP_R_16F
;
case
DType
::
kFloat32
:
return
HIP
BLASLT
_R_32F
;
return
HIP_R_32F
;
case
DType
::
kBFloat16
:
return
HIPBLASLT_R_16B
;
return
HIP_R_16BF
;
#if HIP_VERSION >= 60300000
case
DType
::
kFloat8E4M3
:
return
HIPBLASLT
_R_8F_E4M3
;
return
te_fp8_fnuz
()
?
HIP_R_8F_E4M3_FNUZ
:
HIP
_R_8F_E4M3
;
case
DType
::
kFloat8E5M2
:
return
HIPBLASLT_R_8F_E5M2
;
return
te_fp8_fnuz
()
?
HIP_R_8F_E5M2_FNUZ
:
HIP_R_8F_E5M2
;
#else
case
DType
::
kFloat8E4M3
:
return
HIP_R_8F_E4M3_FNUZ
;
case
DType
::
kFloat8E5M2
:
return
HIP_R_8F_E5M2_FNUZ
;
#endif
default:
NVTE_ERROR
(
"Invalid type"
);
}
...
...
@@ -368,11 +364,7 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
float
))
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
float
),
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
hipLaunchKernelGGL
((
bias_gradient_kernel
<
Tin
,
THREADS_PER_BLOCK
>
),
dim3
(
grid
),
dim3
(
block
),
0
,
stream
,
in
,
out
,
m
,
n
);
}
...
...
@@ -576,11 +568,11 @@ public:
const
std
::
string_view
&
getName
(
const
T
&
val
)
{
return
map
.
at
(
val
);
}
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
)
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
,
std
::
function
<
bool
(
const
T
&
)
>
filter
=
nullptr
)
{
for
(
auto
iter
=
map
.
begin
();
iter
!=
map
.
end
();
++
iter
)
{
if
(
name
==
iter
->
second
)
return
iter
->
first
;
if
(
(
name
==
iter
->
second
)
&&
(
!
filter
||
filter
(
iter
->
first
)))
return
iter
->
first
;
}
NVTE_ERROR
(
"Invalid "
,
label
,
" name: "
,
name
);
}
...
...
@@ -588,14 +580,18 @@ protected:
const
std
::
unordered_map
<
T
,
std
::
string_view
>
&
map
;
};
static
std
::
unordered_map
<
hipblasltDatatype_t
,
std
::
string_view
>
type_name_map
=
{
{
HIPBLASLT_R_32F
,
"float32"
},
{
HIPBLASLT_R_16F
,
"float16"
},
{
HIPBLASLT_R_16B
,
"bfloat16"
},
{
HIPBLASLT_R_8F_E4M3
,
"float8e4m3"
},
{
HIPBLASLT_R_8F_E5M2
,
"float8e5m2"
},
static
std
::
unordered_map
<
hipDataType
,
std
::
string_view
>
type_name_map
=
{
{
HIP_R_32F
,
"float32"
},
{
HIP_R_16F
,
"float16"
},
{
HIP_R_16BF
,
"bfloat16"
},
{
HIP_R_8F_E4M3_FNUZ
,
"float8e4m3"
},
{
HIP_R_8F_E5M2_FNUZ
,
"float8e5m2"
},
#if HIP_VERSION >= 60300000
{
HIP_R_8F_E4M3
,
"float8e4m3"
},
{
HIP_R_8F_E5M2
,
"float8e5m2"
},
#endif
};
static
NameMapper
<
hip
blaslt
Data
t
ype
_t
>
typeNameMapper
(
type_name_map
);
static
NameMapper
<
hipData
T
ype
>
typeNameMapper
(
type_name_map
);
static
std
::
unordered_map
<
hipblasOperation_t
,
std
::
string_view
>
trans_name_map
=
{
{
HIPBLAS_OP_N
,
"N"
},
...
...
@@ -614,24 +610,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
};
static
NameMapper
<
hipblasLtEpilogue_t
>
epilogueNameMapper
(
epi_name_map
);
static
std
::
unordered_map
<
hipblas
Lt
ComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS
LT
_COMPUTE_
F
32
,
"f32"
}
static
std
::
unordered_map
<
hipblasComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS_COMPUTE_32
F
,
"f32"
}
};
static
NameMapper
<
hipblas
Lt
ComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
NameMapper
<
hipblasComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
class
GemmAlgoCache
{
public:
struct
Key
{
int
deviceCap
;
hip
blaslt
Data
t
ype
_t
a_type
,
b_type
,
d_type
,
bias_type
;
hipData
T
ype
a_type
,
b_type
,
d_type
,
bias_type
;
int
m
,
n
,
k
;
int
lda
,
ldb
,
ldd
;
hipblasOperation_t
transa
,
transb
;
hipblasLtEpilogue_t
epilogue
;
Key
(
int
deviceCap_
,
hip
blaslt
Data
t
ype
_t
a_type_
,
hip
blaslt
Data
t
ype
_t
b_type_
,
hip
blaslt
Data
t
ype
_t
d_type_
,
hip
blaslt
Data
t
ype
_t
bias_type_
,
hipData
T
ype
a_type_
,
hipData
T
ype
b_type_
,
hipData
T
ype
d_type_
,
hipData
T
ype
bias_type_
,
int
m_
,
int
n_
,
int
k_
,
int
lda_
,
int
ldb_
,
int
ldd_
,
hipblasOperation_t
transa_
,
hipblasOperation_t
transb_
,
hipblasLtEpilogue_t
epilogue_
)
:
...
...
@@ -865,18 +861,32 @@ protected:
std
::
cout
<<
"[WARNING] Invalid WS size at "
<<
line
<<
"
\n
"
;
continue
;
}
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipblasltDatatype_t
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
);
#if HIP_VERSION >= 60300000
auto
fp8_filter
=
te_fp8_fnuz
()
?
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3
&&
val
!=
HIP_R_8F_E5M2
);
}
:
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3_FNUZ
&&
val
!=
HIP_R_8F_E5M2_FNUZ
);
};
#else
auto
fp8_filter
=
nullptr
;
#endif
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
,
fp8_filter
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
,
fp8_filter
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
,
fp8_filter
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipDataType
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
,
fp8_filter
);
cfg
.
transa
=
transposeNameMapper
.
getValue
(
trans_a
,
"trans_a"
);
cfg
.
transb
=
transposeNameMapper
.
getValue
(
trans_b
,
"trans_b"
);
cfg
.
epilogue
=
epilogueNameMapper
.
getValue
(
epi
,
"epi"
);
//Check and filter out compute and scale types
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLASLT_COMPUTE_F32
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIPBLASLT_R_32F
)
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLAS_COMPUTE_32F
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIP_R_32F
)
{
continue
;
}
...
...
@@ -959,9 +969,9 @@ protected:
csv
<<
cfg
.
deviceCap
<<
cfg
.
m
<<
cfg
.
n
<<
cfg
.
k
<<
transposeNameMapper
.
getName
(
cfg
.
transa
)
<<
transposeNameMapper
.
getName
(
cfg
.
transb
)
<<
typeNameMapper
.
getName
(
cfg
.
a_type
)
<<
typeNameMapper
.
getName
(
cfg
.
b_type
)
<<
typeNameMapper
.
getName
(
cfg
.
d_type
)
<<
((
cfg
.
bias_type
==
(
hip
blaslt
Data
t
ype
_t
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
((
cfg
.
bias_type
==
(
hipData
T
ype
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
cfg
.
lda
<<
cfg
.
ldb
<<
cfg
.
ldd
<<
epilogueNameMapper
.
getName
(
cfg
.
epilogue
)
<<
computeNameMapper
.
getName
(
HIPBLAS
LT
_COMPUTE_
F
32
)
<<
typeNameMapper
.
getName
(
HIP
BLASLT
_R_32F
)
<<
computeNameMapper
.
getName
(
HIPBLAS_COMPUTE_32
F
)
<<
typeNameMapper
.
getName
(
HIP_R_32F
)
<<
algo
.
ws_size_min
<<
algo
.
ws_size_max
<<
algo
.
algoId
<<
algo
.
index
<<
csv_helper
::
end
()
<<
"
\n
"
;
}
...
...
@@ -995,6 +1005,19 @@ static inline int getIntEnv(const char *name, int defval, int minval)
}
//namespace
/* Warning: only call once per device!
* When calling nvte_multi_stream_cublas_gemm with hipblaslt backend
* need to create multiple handles corresponding to compute_streams
* to avoid a handle be used by multi-streams concurrently.
*/
static
void
init_hipblaslt_handles
(
hipblasLtHandle_t
*
hipblaslt_handles
)
{
NVTE_CHECK
(
hipblaslt_handles
!=
nullptr
);
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtCreate
(
&
hipblaslt_handles
[
i
]));
}
}
void
hipblaslt_gemm
(
const
Tensor
*
inputA
,
const
Tensor
*
inputB
,
Tensor
*
outputD
,
...
...
@@ -1014,7 +1037,8 @@ void hipblaslt_gemm(const Tensor *inputA,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
hipStream_t
stream
,
hipblasLtHandle_t
handle
)
{
void
*
A
=
inputA
->
data
.
dptr
;
void
*
A_scale_inverse
=
inputA
->
scale_inv
.
dptr
;
...
...
@@ -1027,10 +1051,10 @@ void hipblaslt_gemm(const Tensor *inputA,
const
bool
gelu
=
pre_gelu_out
!=
nullptr
;
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
const
hipData
T
ype
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hipData
T
ype
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hipData
T
ype
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hipData
T
ype
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
NVTE_CHECK
(
!
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
A_scale_inverse
!=
nullptr
,
"FP8 input to GEMM requires inverse of scale!"
);
...
...
@@ -1050,10 +1074,12 @@ void hipblaslt_gemm(const Tensor *inputA,
int
device_id
;
NVTE_CHECK_CUDA
(
hipGetDevice
(
&
device_id
));
hipblasLtHandle_t
handle
=
cached_handles
.
get
(
device_id
);
if
(
handle
==
nullptr
)
{
handle
=
cached_handles
.
obtain
(
device_id
);
if
(
handle
==
nullptr
)
{
handle
=
cached_handles
.
get
(
device_id
);
if
(
handle
==
nullptr
)
{
handle
=
cached_handles
.
obtain
(
device_id
);
}
}
hipblasLtMatmulDesc_t
operationDesc
=
nullptr
;
...
...
@@ -1064,7 +1090,7 @@ void hipblaslt_gemm(const Tensor *inputA,
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
// default to tf32 except for e5m2 inputs where the config is not supported
hipblas
Lt
ComputeType_t
gemm_compute_type
=
HIPBLAS
LT
_COMPUTE_
F
32
;
hipblasComputeType_t
gemm_compute_type
=
HIPBLAS_COMPUTE_32
F
;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Adesc
,
A_type
,
...
...
@@ -1077,7 +1103,7 @@ void hipblaslt_gemm(const Tensor *inputA,
ldb
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Ddesc
,
D_type
,
m
,
n
,
ldd
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP
BLASLT
_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
transa
)));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSB
,
...
...
@@ -1154,7 +1180,7 @@ void hipblaslt_gemm(const Tensor *inputA,
&
epilogue
,
sizeof
(
epilogue
)));
GemmAlgoCache
::
Key
gemm_cfg
(
algoCache
.
device_cap
(
device_id
),
A_type
,
B_type
,
D_type
,
use_fp8
?
bias_type
:
(
hip
blaslt
Data
t
ype
_t
)
-
1
,
use_fp8
?
bias_type
:
(
hipData
T
ype
)
-
1
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
epilogue
);
GemmAlgoCache
::
Algo
cached_algo
;
if
(
algoCache
.
find
(
gemm_cfg
,
workspaceSize
,
cached_algo
)
==
0
||
!
cached_algo
.
algo
.
has_value
())
...
...
@@ -1231,6 +1257,7 @@ void hipblaslt_gemm(const Tensor *inputA,
<<
" in range ["
<<
firstAlgo
<<
"-"
<<
(
algoTuneCount
-
1
)
<<
"] with "
<<
tuneLoopCount
<<
" loops "
<<
std
::
endl
;
NVTE_CHECK_CUDA
(
hipStreamSynchronize
(
stream
));
hipStream_t
profilingStream
;
NVTE_CHECK_CUDA
(
hipStreamCreateWithFlags
(
&
profilingStream
,
hipStreamNonBlocking
));
using
tuning_clock
=
std
::
chrono
::
steady_clock
;
...
...
@@ -1475,11 +1502,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
D_temp
=
D
;
...
...
@@ -1570,11 +1593,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
)
);
// The bias gradient is for the first linear layer
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
bias_tmp
=
bias_ptr
;
...
...
@@ -1600,11 +1619,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
...
...
@@ -1652,11 +1667,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
bias_tmp
=
bias_ptr
;
...
...
@@ -1683,11 +1694,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
if
(
D_type
==
rocblas_datatype_f16_r
||
D_type
==
rocblas_datatype_bf16_r
)
{
...
...
@@ -1788,11 +1795,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
D_temp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
D_temp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
...
...
@@ -1804,7 +1807,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
int
ldb
,
int
ldd
,
bool
transa
,
bool
transb
,
bool
grad
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
)
const
Tensor
*
inputCounter
,
hipStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
,
int
compute_stream_offset
=
-
1
)
{
/*If no backend is specified with env variable use HIPBLASLT unless it is disabled
If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
...
...
@@ -1845,16 +1848,31 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
#ifdef USE_HIPBLASLT
if
(
use_hipblaslt
)
if
(
use_hipblaslt
||
!
use_rocblas
)
{
// Check compute_stream_offset valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
num_streams
);
hipblasLtHandle_t
handle
=
nullptr
;
if
(
compute_stream_offset
!=
-
1
)
{
// Init hipblaslt handles (once, globally)
static
std
::
once_flag
init_flag
;
static
hipblasLtHandle_t
hipblaslt_handles
[
num_streams
];
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
handle
=
hipblaslt_handles
[
compute_stream_offset
];
}
hipblaslt_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
workspace
,
workspaceSize
,
accumulate
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
workspace
,
workspaceSize
,
accumulate
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
,
handle
);
return
;
}
#endif
...
...
transformer_engine/common/include/transformer_engine/gemm.h
View file @
229be5e8
...
...
@@ -42,7 +42,7 @@ extern "C" {
void
nvte_cublas_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
);
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
,
int
compute_stream_offset
=
-
1
);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
...
...
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
NVTETensor
counter
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
);
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
,
int
compute_stream_offset
=
-
1
);
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
* on multiple streams.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment