Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
ea272d4a
Commit
ea272d4a
authored
Mar 20, 2025
by
yuguo
Browse files
[DCU] support for ROCm FP8 FNUZ and OCP formats
parent
a248abb6
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
178 additions
and
89 deletions
+178
-89
tests/cpp/test_common.h
tests/cpp/test_common.h
+3
-4
transformer_engine/common/amd_detail/hip_float8.h
transformer_engine/common/amd_detail/hip_float8.h
+108
-4
transformer_engine/common/common.h
transformer_engine/common/common.h
+4
-4
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+61
-75
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
...pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
+2
-2
No files found.
tests/cpp/test_common.h
View file @
ea272d4a
...
...
@@ -61,10 +61,9 @@ using bf16 = nv_bfloat16;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
bf16
=
__hip_bfloat16
;
using
fp8e4m3
=
hip_f8
<
hip_f8_type
::
fp8
>
;
using
fp8e5m2
=
hip_f8
<
hip_f8_type
::
bf8
>
;
#endif
using
fp8e4m3
=
te_hip_fp8_e4m3
;
using
fp8e5m2
=
te_hip_fp8_e5m2
;
#endif //USE_ROCM
using
fp8e8m0
=
uint8_t
;
template
<
typename
T
>
...
...
transformer_engine/common/amd_detail/hip_float8.h
View file @
ea272d4a
...
...
@@ -4,10 +4,108 @@
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#pragma once
// FP8 header version 0.3, 2021/05/11
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#if HIP_VERSION >= 60200000
#include <hip/hip_fp8.h>
#if HIP_VERSION >= 60300000
#if !defined(__HIP_DEVICE_COMPILE__)
#include <optional>
#include "../util/string.h"
/* Platforms that have both MI300 family and other families GPUs are unknown and not supported.
* Thus, FP8 format is selected once by the current (any) GPU architecture.
*/
static
bool
_te_check_fp8_fnuz
()
{
hipDeviceProp_t
prop
;
hipError_t
res
=
hipGetDeviceProperties
(
&
prop
,
0
);
if
(
res
!=
hipSuccess
)
{
//TODO: better error out system
throw
std
::
runtime_error
(
transformer_engine
::
concat_strings
(
"hipGetDeviceProperties failed with error: "
,
hipGetErrorString
(
res
)));
}
return
prop
.
major
==
9
&&
prop
.
minor
==
4
;
}
/* Reports whether the FNUZ FP8 encodings (true) or the OCP encodings (false) are in use.
 *
 * The device is probed by _te_check_fp8_fnuz() on the first call only and the answer is
 * cached for the rest of the process. The cache is a function-local static, so its
 * initialization is synchronized by the language: concurrent first calls from several
 * threads are safe. If the probe throws, the exception propagates to the caller and the
 * probe is attempted again on the next call.
 */
static inline bool te_fp8_fnuz() {
  static const bool use_fnuz = _te_check_fp8_fnuz();
  return use_fnuz;
}
/* Host-pass FP8 storage that holds either the FNUZ or the OCP encoding of a value.
 * Which member is read or written is decided at run time by te_fp8_fnuz().
 *
 * The __device__ methods of _te_hip_fp8 are dummies that exist only so the code compiles:
 * HIPCC also compiles __device__ and __global__ functions during the host pass. Their
 * results are discarded, so those methods are declared but never defined.
 *
 * Constructors are named with the injected class name rather than a template-id
 * (_te_hip_fp8<FNUZ, OCP>), because the template-id form is no longer valid in C++20.
 */
template <typename FNUZ, typename OCP>
union _te_hip_fp8 {
  FNUZ fnuz;  // used when te_fp8_fnuz() is true
  OCP ocp;    // used when te_fp8_fnuz() is false

  __host__ __device__ _te_hip_fp8() = default;

  // Host conversion to float: decode whichever encoding is active.
  __host__ operator float() const {
    return te_fp8_fnuz() ? fnuz.operator float() : ocp.operator float();
  }
  // Declared only; see the note above.
  __device__ operator float() const;

  // Host conversion from float: encode into whichever encoding is active.
  __host__ _te_hip_fp8(const float& v) {
    if (te_fp8_fnuz())
      fnuz = v;
    else
      ocp = v;
  }
  // Declared only; see the note above.
  __device__ _te_hip_fp8(const float& v);
};
// Host-pass storage types: each one pairs the FNUZ and the OCP flavor of the same
// FP8 format inside a _te_hip_fp8 union.
using _te_hip_fp8_e4m3 = _te_hip_fp8<__hip_fp8_e4m3_fnuz, __hip_fp8_e4m3>;
using _te_hip_fp8_e5m2 = _te_hip_fp8<__hip_fp8_e5m2_fnuz, __hip_fp8_e5m2>;
#elif HIP_FP8_TYPE_FNUZ
typedef
__hip_fp8_e4m3_fnuz
_te_hip_fp8_e4m3
;
typedef
__hip_fp8_e5m2_fnuz
_te_hip_fp8_e5m2
;
// Device compilation pass with HIP_FP8_TYPE_FNUZ set: the FP8 flavor is fixed at
// compile time, so the answer is always "FNUZ".
static inline bool te_fp8_fnuz() {
  constexpr bool kUseFnuz = true;
  return kUseFnuz;
}
#elif HIP_FP8_TYPE_OCP
typedef
__hip_fp8_e4m3
_te_hip_fp8_e4m3
;
typedef
__hip_fp8_e5m2
_te_hip_fp8_e5m2
;
// Device compilation pass with HIP_FP8_TYPE_OCP set: the FP8 flavor is fixed at
// compile time, so the answer is always "not FNUZ" (i.e. OCP).
static inline bool te_fp8_fnuz() {
  constexpr bool kUseFnuz = false;
  return kUseFnuz;
}
#else
#error "Unsupported HIP_FP8_TYPE"
#endif //__HIP_DEVICE_COMPILE__
#else //HIP_VERSION >= 60300000
// HIP older than 6.3: the storage types are always the FNUZ flavors.
using _te_hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
using _te_hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#endif //HIP_VERSION >= 60300000
// FP8 E4M3 value type: a one-byte wrapper around the storage type _te_hip_fp8_e4m3
// selected above. It forwards float conversions in both directions to that storage type.
struct te_hip_fp8_e4m3 {
  _te_hip_fp8_e4m3 data;

  __host__ __device__ te_hip_fp8_e4m3() = default;

  // Encode a float by assigning it to the underlying storage.
  __host__ __device__ te_hip_fp8_e4m3(const float& v) {
    data = v;
  }

  // Decode through the storage type's own float conversion.
  __host__ __device__ operator float() const {
    return data.operator float();
  }
};
// The wrapper must not add any bytes on top of the single byte of storage.
static_assert(sizeof(te_hip_fp8_e4m3) == 1, "Size mismatch");
// FP8 E5M2 value type: a one-byte wrapper around the storage type _te_hip_fp8_e5m2
// selected above. It forwards float conversions in both directions to that storage type.
//
// Declared as a struct, like its sibling te_hip_fp8_e4m3 and like the definitions used in
// the other preprocessor branches of this header; with a single data member the layout
// and size are the same as for the former single-member union.
struct te_hip_fp8_e5m2 {
  _te_hip_fp8_e5m2 data;

  __host__ __device__ te_hip_fp8_e5m2() = default;

  // Decode through the storage type's own float conversion.
  __host__ __device__ operator float() const {
    return data.operator float();
  }

  // Encode a float by assigning it to the underlying storage.
  __host__ __device__ te_hip_fp8_e5m2(const float& v) {
    data = v;
  }
};
// The wrapper must not add any bytes on top of the single byte of storage.
static_assert(sizeof(te_hip_fp8_e5m2) == 1, "Size mismatch");
#else //HIP_VERSION >= 60200000
// FP8 header version 0.3, 2021/05/11
#define HIP_HOST_DEVICE __host__ __device__
#define HIP_DEVICE __device__
#define HIP_HOST __host__
...
...
@@ -69,7 +167,6 @@ static inline __host__ bool get_hip_f8_bias_mode() {
}
#endif // __HIPCC_RTC__
#ifdef __HIPCC__
static
__device__
bool
hip_f8_bias_mode_bit_device
=
true
;
static
inline
__device__
bool
get_hip_f8_bias_mode
()
{
...
...
@@ -91,7 +188,6 @@ static void set_hip_f8_bias_mode_optimal() {
hip_f8_bias_mode_bit_host
=
true
;
}
#endif // __HIPCC_RTC__
#endif // __HIPCC__
template
<
hip_f8_type
T
>
...
...
@@ -376,7 +472,6 @@ struct hip_f8 {
}
};
#ifdef __HIPCC__
template
<
hip_f8_type
T
>
struct
hip_f8x4
{
...
...
@@ -455,4 +550,13 @@ __device__ hip_float32x4 mfma_f32_16x16x32(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip
template
<
hip_f8_type
T_A
,
hip_f8_type
T_B
>
__device__
hip_float32x16
mfma_f32_32x32x16
(
hip_f8x8
<
T_A
>
a
,
hip_f8x8
<
T_B
>
b
,
hip_float32x16
c
);
// HIP older than 6.2: the FP8 value types are the hip_f8 implementation from this header.
using te_hip_fp8_e4m3 = hip_f8<hip_f8_type::fp8>;
using te_hip_fp8_e5m2 = hip_f8<hip_f8_type::bf8>;
#endif //HIP_VERSION >= 60200000
#else //__HIPCC__
// Not compiled by HIPCC: one-byte placeholder types with the same names as the FP8 value
// types, each holding a single char of storage and offering no conversions.
struct te_hip_fp8_e4m3 {
  char storage;
};
struct te_hip_fp8_e5m2 {
  char storage;
};
#endif //__HIPCC__
\ No newline at end of file
transformer_engine/common/common.h
View file @
ea272d4a
...
...
@@ -224,8 +224,8 @@ using fp8e4m3 = __nv_fp8_e4m3;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
bf16
=
hip_bfloat16
;
using
fp8e4m3
=
hip_f
8
<
hip_f8_type
::
fp8
>
;
using
fp8e5m2
=
hip_f
8
<
hip_f8_type
::
bf8
>
;
using
fp8e4m3
=
te_
hip_f
p8_e4m3
;
using
fp8e5m2
=
te_
hip_f
p8_e5m2
;
#endif
#if CUDA_VERSION >= 12080
using
fp8e8m0
=
__nv_fp8_e8m0
;
...
...
@@ -248,8 +248,8 @@ TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME
(
half
)
#ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_TYPE_NAME
(
hip_bfloat16
)
TRANSFORMER_ENGINE_TYPE_NAME
(
hip_f
8
<
hip_f8_type
::
fp8
>
)
TRANSFORMER_ENGINE_TYPE_NAME
(
hip_f
8
<
hip_f8_type
::
bf8
>
)
TRANSFORMER_ENGINE_TYPE_NAME
(
te_
hip_f
p8_e4m3
)
TRANSFORMER_ENGINE_TYPE_NAME
(
te_
hip_f
p8_e5m2
)
#else
TRANSFORMER_ENGINE_TYPE_NAME
(
nv_bfloat16
)
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e4m3
)
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
ea272d4a
...
...
@@ -36,30 +36,26 @@ namespace {
#ifdef USE_HIPBLASLT
#if HIP_VERSION >= 60000000
typedef
hipDataType
hipblasltDatatype_t
;
typedef
hipblasComputeType_t
hipblasLtComputeType_t
;
#define HIPBLASLT_R_16F HIP_R_16F
#define HIPBLASLT_R_32F HIP_R_32F
#define HIPBLASLT_R_16B HIP_R_16BF
#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
#endif // #if HIP_VERSION >= 60000000
hipblasltDatatype_t
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
static
hipDataType
get_hipblaslt_dtype
(
const
transformer_engine
::
DType
t
)
{
using
namespace
transformer_engine
;
switch
(
t
)
{
case
DType
::
kFloat16
:
return
HIP
BLASLT
_R_16F
;
return
HIP_R_16F
;
case
DType
::
kFloat32
:
return
HIP
BLASLT
_R_32F
;
return
HIP_R_32F
;
case
DType
::
kBFloat16
:
return
HIPBLASLT_R_16B
;
return
HIP_R_16BF
;
#if HIP_VERSION >= 60300000
case
DType
::
kFloat8E4M3
:
return
HIPBLASLT
_R_8F_E4M3
;
return
te_fp8_fnuz
()
?
HIP_R_8F_E4M3_FNUZ
:
HIP
_R_8F_E4M3
;
case
DType
::
kFloat8E5M2
:
return
HIPBLASLT_R_8F_E5M2
;
return
te_fp8_fnuz
()
?
HIP_R_8F_E5M2_FNUZ
:
HIP_R_8F_E5M2
;
#else
case
DType
::
kFloat8E4M3
:
return
HIP_R_8F_E4M3_FNUZ
;
case
DType
::
kFloat8E5M2
:
return
HIP_R_8F_E5M2_FNUZ
;
#endif
default:
NVTE_ERROR
(
"Invalid type"
);
}
...
...
@@ -367,11 +363,7 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
float
))
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
float
),
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
hipLaunchKernelGGL
((
bias_gradient_kernel
<
Tin
,
THREADS_PER_BLOCK
>
),
dim3
(
grid
),
dim3
(
block
),
0
,
stream
,
in
,
out
,
m
,
n
);
}
...
...
@@ -575,11 +567,11 @@ public:
const
std
::
string_view
&
getName
(
const
T
&
val
)
{
return
map
.
at
(
val
);
}
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
)
T
getValue
(
const
std
::
string
&
name
,
const
char
*
label
=
""
,
std
::
function
<
bool
(
const
T
&
)
>
filter
=
nullptr
)
{
for
(
auto
iter
=
map
.
begin
();
iter
!=
map
.
end
();
++
iter
)
{
if
(
name
==
iter
->
second
)
return
iter
->
first
;
if
(
(
name
==
iter
->
second
)
&&
(
!
filter
||
filter
(
iter
->
first
)))
return
iter
->
first
;
}
NVTE_ERROR
(
"Invalid "
,
label
,
" name: "
,
name
);
}
...
...
@@ -587,14 +579,18 @@ protected:
const
std
::
unordered_map
<
T
,
std
::
string_view
>
&
map
;
};
static
std
::
unordered_map
<
hipblasltDatatype_t
,
std
::
string_view
>
type_name_map
=
{
{
HIPBLASLT_R_32F
,
"float32"
},
{
HIPBLASLT_R_16F
,
"float16"
},
{
HIPBLASLT_R_16B
,
"bfloat16"
},
{
HIPBLASLT_R_8F_E4M3
,
"float8e4m3"
},
{
HIPBLASLT_R_8F_E5M2
,
"float8e5m2"
},
static
std
::
unordered_map
<
hipDataType
,
std
::
string_view
>
type_name_map
=
{
{
HIP_R_32F
,
"float32"
},
{
HIP_R_16F
,
"float16"
},
{
HIP_R_16BF
,
"bfloat16"
},
{
HIP_R_8F_E4M3_FNUZ
,
"float8e4m3"
},
{
HIP_R_8F_E5M2_FNUZ
,
"float8e5m2"
},
#if HIP_VERSION >= 60300000
{
HIP_R_8F_E4M3
,
"float8e4m3"
},
{
HIP_R_8F_E5M2
,
"float8e5m2"
},
#endif
};
static
NameMapper
<
hip
blaslt
Data
t
ype
_t
>
typeNameMapper
(
type_name_map
);
static
NameMapper
<
hipData
T
ype
>
typeNameMapper
(
type_name_map
);
static
std
::
unordered_map
<
hipblasOperation_t
,
std
::
string_view
>
trans_name_map
=
{
{
HIPBLAS_OP_N
,
"N"
},
...
...
@@ -613,24 +609,24 @@ static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map =
};
static
NameMapper
<
hipblasLtEpilogue_t
>
epilogueNameMapper
(
epi_name_map
);
static
std
::
unordered_map
<
hipblas
Lt
ComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS
LT
_COMPUTE_
F
32
,
"f32"
}
static
std
::
unordered_map
<
hipblasComputeType_t
,
std
::
string_view
>
comp_name_map
=
{
{
HIPBLAS_COMPUTE_32
F
,
"f32"
}
};
static
NameMapper
<
hipblas
Lt
ComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
NameMapper
<
hipblasComputeType_t
>
computeNameMapper
(
comp_name_map
);
static
class
GemmAlgoCache
{
public:
struct
Key
{
int
deviceCap
;
hip
blaslt
Data
t
ype
_t
a_type
,
b_type
,
d_type
,
bias_type
;
hipData
T
ype
a_type
,
b_type
,
d_type
,
bias_type
;
int
m
,
n
,
k
;
int
lda
,
ldb
,
ldd
;
hipblasOperation_t
transa
,
transb
;
hipblasLtEpilogue_t
epilogue
;
Key
(
int
deviceCap_
,
hip
blaslt
Data
t
ype
_t
a_type_
,
hip
blaslt
Data
t
ype
_t
b_type_
,
hip
blaslt
Data
t
ype
_t
d_type_
,
hip
blaslt
Data
t
ype
_t
bias_type_
,
hipData
T
ype
a_type_
,
hipData
T
ype
b_type_
,
hipData
T
ype
d_type_
,
hipData
T
ype
bias_type_
,
int
m_
,
int
n_
,
int
k_
,
int
lda_
,
int
ldb_
,
int
ldd_
,
hipblasOperation_t
transa_
,
hipblasOperation_t
transb_
,
hipblasLtEpilogue_t
epilogue_
)
:
...
...
@@ -865,17 +861,31 @@ protected:
continue
;
}
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipblasltDatatype_t
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
);
#if HIP_VERSION >= 60300000
auto
fp8_filter
=
te_fp8_fnuz
()
?
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3
&&
val
!=
HIP_R_8F_E5M2
);
}
:
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3_FNUZ
&&
val
!=
HIP_R_8F_E5M2_FNUZ
);
};
#else
auto
fp8_filter
=
nullptr
;
#endif
cfg
.
a_type
=
typeNameMapper
.
getValue
(
type_a
,
"type_a"
,
fp8_filter
);
cfg
.
b_type
=
typeNameMapper
.
getValue
(
type_b
,
"type_b"
,
fp8_filter
);
cfg
.
d_type
=
typeNameMapper
.
getValue
(
type_d
,
"type_d"
,
fp8_filter
);
cfg
.
bias_type
=
(
bias_type
==
"-"
)
?
(
hipDataType
)
-
1
:
typeNameMapper
.
getValue
(
bias_type
,
"bias_type"
,
fp8_filter
);
cfg
.
transa
=
transposeNameMapper
.
getValue
(
trans_a
,
"trans_a"
);
cfg
.
transb
=
transposeNameMapper
.
getValue
(
trans_b
,
"trans_b"
);
cfg
.
epilogue
=
epilogueNameMapper
.
getValue
(
epi
,
"epi"
);
//Check and filter out compute and scale types
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLASLT_COMPUTE_F32
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIPBLASLT_R_32F
)
if
(
computeNameMapper
.
getValue
(
comp
,
"comp"
)
!=
HIPBLAS_COMPUTE_32F
||
typeNameMapper
.
getValue
(
scale
,
"scale"
)
!=
HIP_R_32F
)
{
continue
;
}
...
...
@@ -958,9 +968,9 @@ protected:
csv
<<
cfg
.
deviceCap
<<
cfg
.
m
<<
cfg
.
n
<<
cfg
.
k
<<
transposeNameMapper
.
getName
(
cfg
.
transa
)
<<
transposeNameMapper
.
getName
(
cfg
.
transb
)
<<
typeNameMapper
.
getName
(
cfg
.
a_type
)
<<
typeNameMapper
.
getName
(
cfg
.
b_type
)
<<
typeNameMapper
.
getName
(
cfg
.
d_type
)
<<
((
cfg
.
bias_type
==
(
hip
blaslt
Data
t
ype
_t
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
((
cfg
.
bias_type
==
(
hipData
T
ype
)
-
1
)
?
"-"
:
typeNameMapper
.
getName
(
cfg
.
bias_type
))
<<
cfg
.
lda
<<
cfg
.
ldb
<<
cfg
.
ldd
<<
epilogueNameMapper
.
getName
(
cfg
.
epilogue
)
<<
computeNameMapper
.
getName
(
HIPBLAS
LT
_COMPUTE_
F
32
)
<<
typeNameMapper
.
getName
(
HIP
BLASLT
_R_32F
)
<<
computeNameMapper
.
getName
(
HIPBLAS_COMPUTE_32
F
)
<<
typeNameMapper
.
getName
(
HIP_R_32F
)
<<
algo
.
ws_size_min
<<
algo
.
ws_size_max
<<
algo
.
algoId
<<
algo
.
index
<<
csv_helper
::
end
()
<<
"
\n
"
;
}
...
...
@@ -1026,10 +1036,10 @@ void hipblaslt_gemm(const Tensor *inputA,
const
bool
gelu
=
pre_gelu_out
!=
nullptr
;
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hip
blaslt
Data
t
ype
_t
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
const
hipData
T
ype
A_type
=
get_hipblaslt_dtype
(
inputA
->
data
.
dtype
);
const
hipData
T
ype
B_type
=
get_hipblaslt_dtype
(
inputB
->
data
.
dtype
);
const
hipData
T
ype
D_type
=
get_hipblaslt_dtype
(
outputD
->
data
.
dtype
);
const
hipData
T
ype
bias_type
=
get_hipblaslt_dtype
(
inputBias
->
data
.
dtype
);
NVTE_CHECK
(
!
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
A_scale_inverse
!=
nullptr
,
"FP8 input to GEMM requires inverse of scale!"
);
...
...
@@ -1063,7 +1073,7 @@ void hipblaslt_gemm(const Tensor *inputA,
int64_t
ld_gelumat
=
(
int64_t
)
ldd
;
// default to tf32 except for e5m2 inputs where the config is not supported
hipblas
Lt
ComputeType_t
gemm_compute_type
=
HIPBLAS
LT
_COMPUTE_
F
32
;
hipblasComputeType_t
gemm_compute_type
=
HIPBLAS_COMPUTE_32
F
;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Adesc
,
A_type
,
...
...
@@ -1076,7 +1086,7 @@ void hipblaslt_gemm(const Tensor *inputA,
ldb
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatrixLayoutCreate
(
&
Ddesc
,
D_type
,
m
,
n
,
ldd
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP
BLASLT
_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescCreate
(
&
operationDesc
,
gemm_compute_type
,
HIP_R_32F
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
transa
)));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescSetAttribute
(
operationDesc
,
HIPBLASLT_MATMUL_DESC_TRANSB
,
...
...
@@ -1153,7 +1163,7 @@ void hipblaslt_gemm(const Tensor *inputA,
&
epilogue
,
sizeof
(
epilogue
)));
GemmAlgoCache
::
Key
gemm_cfg
(
algoCache
.
device_cap
(
device_id
),
A_type
,
B_type
,
D_type
,
use_fp8
?
bias_type
:
(
hip
blaslt
Data
t
ype
_t
)
-
1
,
use_fp8
?
bias_type
:
(
hipData
T
ype
)
-
1
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
epilogue
);
GemmAlgoCache
::
Algo
cached_algo
;
if
(
algoCache
.
find
(
gemm_cfg
,
workspaceSize
,
cached_algo
)
==
0
||
!
cached_algo
.
algo
.
has_value
())
...
...
@@ -1468,11 +1478,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
D_temp
,
sizeof
(
float
)
*
m
*
n
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
D_temp
=
D
;
...
...
@@ -1565,11 +1571,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
)
);
// The bias gradient is for the first linear layer
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
input_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
bias_tmp
=
bias_ptr
;
...
...
@@ -1595,11 +1597,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
...
...
@@ -1647,11 +1645,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipMalloc
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
bias_tmp
,
sizeof
(
float
)
*
output_dim
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
else
{
bias_tmp
=
bias_ptr
;
...
...
@@ -1678,11 +1672,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
bias_tmp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
bias_tmp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
if
(
D_type
==
rocblas_datatype_f16_r
||
D_type
==
rocblas_datatype_bf16_r
)
{
...
...
@@ -1783,11 +1773,7 @@ void rocblas_gemm(const Tensor *inputA,
if
(
!
stream_order_alloc
){
NVTE_CHECK_CUDA
(
hipFree
(
D_temp
)
);
}
else
{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA
(
hipFreeAsync
(
D_temp
,
stream
)
);
#else
NVTE_ERROR
(
"Stream order allocation is supported on ROCm 5.3 and above."
);
#endif
}
}
}
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
View file @
ea272d4a
...
...
@@ -36,8 +36,8 @@ using MATH_T = float;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
fp8e4m3
=
hip_f
8
<
hip_f8_type
::
fp8
>
;
using
fp8e5m2
=
hip_f
8
<
hip_f8_type
::
bf8
>
;
using
fp8e4m3
=
te_
hip_f
p8_e4m3
;
using
fp8e5m2
=
te_
hip_f
p8_e5m2
;
#endif
using
transformer_engine
::
DType
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment