Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
5753c5bb
Commit
5753c5bb
authored
May 26, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.3'
parents
f9d870f4
7d0f5b7f
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
37 additions
and
157 deletions
+37
-157
hipify_custom_map.json
hipify_custom_map.json
+6
-2
tests/cpp/operator/test_cast_float8blockwise.cu
tests/cpp/operator/test_cast_float8blockwise.cu
+16
-0
tests/cpp/test_common.h
tests/cpp/test_common.h
+3
-17
tests/pytorch/distributed/test_fusible_ops.py
tests/pytorch/distributed/test_fusible_ops.py
+1
-1
tests/pytorch/test_recipe.py
tests/pytorch/test_recipe.py
+1
-1
transformer_engine/common/CMakeLists.txt
transformer_engine/common/CMakeLists.txt
+0
-4
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+0
-5
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
...ngine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
+0
-57
transformer_engine/common/common.h
transformer_engine/common/common.h
+0
-20
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+3
-13
transformer_engine/common/multi_tensor/adam.cu
transformer_engine/common/multi_tensor/adam.cu
+1
-5
transformer_engine/common/normalization/common.h
transformer_engine/common/normalization/common.h
+0
-6
transformer_engine/common/permutation/permutation.cu
transformer_engine/common/permutation/permutation.cu
+0
-2
transformer_engine/common/recipe/__init__.py
transformer_engine/common/recipe/__init__.py
+1
-1
transformer_engine/common/recipe/delayed_scaling.cu
transformer_engine/common/recipe/delayed_scaling.cu
+0
-4
transformer_engine/common/util/rtc.cpp
transformer_engine/common/util/rtc.cpp
+3
-8
transformer_engine/common/utils.cuh
transformer_engine/common/utils.cuh
+0
-9
transformer_engine/pytorch/optimizers/fused_adam.py
transformer_engine/pytorch/optimizers/fused_adam.py
+1
-1
transformer_engine/pytorch/tensor/utils.py
transformer_engine/pytorch/tensor/utils.py
+1
-1
No files found.
hipify_custom_map.json
View file @
5753c5bb
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
"../util/cuda_runtime.h"
:
"../util/hip_runtime.h"
,
"../util/cuda_runtime.h"
:
"../util/hip_runtime.h"
,
"common/util/cuda_driver.h"
:
"common/util/hip_driver.h"
,
"common/util/cuda_driver.h"
:
"common/util/hip_driver.h"
,
"../util/cuda_driver.h"
:
"../util/hip_driver.h"
,
"../util/cuda_driver.h"
:
"../util/hip_driver.h"
,
"./util/cuda_driver.h"
:
"./util/hip_driver.h"
,
"common/util/cuda_nvml.h"
:
"common/util/hip_nvml.h"
,
"common/util/cuda_nvml.h"
:
"common/util/hip_nvml.h"
,
"common/utils.cuh"
:
"common/utils_hip.cuh"
,
"common/utils.cuh"
:
"common/utils_hip.cuh"
,
"common/transpose/cast_transpose.h"
:
"common/transpose/cast_transpose_hip.h"
,
"common/transpose/cast_transpose.h"
:
"common/transpose/cast_transpose_hip.h"
,
...
@@ -15,14 +16,17 @@
...
@@ -15,14 +16,17 @@
"/logging.h"
:
"/logging_hip.h"
,
"/logging.h"
:
"/logging_hip.h"
,
"/system.h"
:
"/system_hip.h"
,
"/system.h"
:
"/system_hip.h"
,
"<cuda_bf16.h>"
:
"<hip/hip_bf16.h>"
,
"<cuda_bf16.h>"
:
"<hip/hip_bf16.h>"
,
"<cuda_fp8.h>"
:
"\"amd_detail/hip_float8.h\""
,
"<cuda_fp8.h>"
:
"<hip/hip_fp8.h>"
,
"CUfunc_cache"
:
"hipFuncCache_t"
,
"CUfunc_cache"
:
"hipFuncCache_t"
,
"<nvtx3/nvToolsExt.h>"
:
"<roctracer/roctx.h>"
,
"<nvtx3/nvToolsExt.h>"
:
"<roctracer/roctx.h>"
,
"cudaLaunchKernelExC"
:
"hipLaunchKernelExC"
,
"cudaLaunchKernelExC"
:
"hipLaunchKernelExC"
,
"cudaLaunchConfig_t"
:
"hipLaunchConfig_t"
,
"cudaLaunchConfig_t"
:
"hipLaunchConfig_t"
,
"cudaLaunchAttributeClusterDimension"
:
"hipLaunchAttributeClusterDimension"
,
"cudaLaunchAttributeClusterDimension"
:
"hipLaunchAttributeClusterDimension"
,
"cudaLaunchAttributeCooperative"
:
"hipLaunchAttributeCooperative"
,
"cudaLaunchAttributeCooperative"
:
"hipLaunchAttributeCooperative"
,
"cudaLaunchAttribute"
:
"hipLaunchAttribute"
"cudaLaunchAttribute"
:
"hipLaunchAttribute"
,
"__nv_fp8_e4m3"
:
"__hip_fp8_e4m3"
,
"__nv_fp8_e5m2"
:
"__hip_fp8_e5m2"
,
"nv_bfloat16"
:
"__hip_bfloat16"
}
}
}
}
tests/cpp/operator/test_cast_float8blockwise.cu
View file @
5753c5bb
...
@@ -50,7 +50,11 @@ void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale
...
@@ -50,7 +50,11 @@ void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale
float
input_type_max_val
=
Quantized_Limits
<
InputType
>::
max
();
float
input_type_max_val
=
Quantized_Limits
<
InputType
>::
max
();
float
quant_type_max_val
=
Quantized_Limits
<
OutputType
>::
max
();
float
quant_type_max_val
=
Quantized_Limits
<
OutputType
>::
max
();
float
eps
=
opts
.
amax_epsilon
;
float
eps
=
opts
.
amax_epsilon
;
#ifdef __HIP_PLATFORM_AMD__
amax
=
amax
>
eps
?
amax
:
eps
;
#else
amax
=
std
::
max
(
amax
,
eps
);
amax
=
std
::
max
(
amax
,
eps
);
#endif
float
qscale
=
quant_type_max_val
/
amax
;
float
qscale
=
quant_type_max_val
/
amax
;
if
(
std
::
isinf
(
qscale
))
{
if
(
std
::
isinf
(
qscale
))
{
qscale
=
input_type_max_val
;
qscale
=
input_type_max_val
;
...
@@ -101,7 +105,11 @@ void ref_quantize(const ProcessingMethod processing_method, const InputType* inp
...
@@ -101,7 +105,11 @@ void ref_quantize(const ProcessingMethod processing_method, const InputType* inp
continue
;
continue
;
}
}
float
val
=
static_cast
<
float
>
(
input
[
y_pos
*
width
+
x_pos
]);
float
val
=
static_cast
<
float
>
(
input
[
y_pos
*
width
+
x_pos
]);
#ifdef __HIP_PLATFORM_AMD__
amax
=
amax
>
std
::
abs
(
val
)
?
amax
:
std
::
abs
(
val
);
#else
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
#endif
}
}
}
}
...
@@ -172,7 +180,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
...
@@ -172,7 +180,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
continue
;
continue
;
}
}
float
val
=
static_cast
<
float
>
(
input
[
y
*
width
+
x_pos
]);
float
val
=
static_cast
<
float
>
(
input
[
y
*
width
+
x_pos
]);
#ifdef __HIP_PLATFORM_AMD__
amax
=
amax
>
std
::
abs
(
val
)
?
amax
:
std
::
abs
(
val
);
#else
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
#endif
}
}
// We've calculated amax for a tile. Calculate scale and
// We've calculated amax for a tile. Calculate scale and
...
@@ -204,7 +216,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
...
@@ -204,7 +216,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
continue
;
continue
;
}
}
float
val
=
static_cast
<
float
>
(
input
[
x
+
y_pos
*
width
]);
float
val
=
static_cast
<
float
>
(
input
[
x
+
y_pos
*
width
]);
#ifdef __HIP_PLATFORM_AMD__
amax
=
amax
>
std
::
abs
(
val
)
?
amax
:
std
::
abs
(
val
);
#else
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
#endif
}
}
// We've calculated amax for a tile. Calculate scale and
// We've calculated amax for a tile. Calculate scale and
...
...
tests/cpp/test_common.h
View file @
5753c5bb
...
@@ -12,14 +12,8 @@
...
@@ -12,14 +12,8 @@
#include <random>
#include <random>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp8.h>
#else
#include <hip/hip_bf16.h>
#include "amd_detail/hip_float8.h"
#endif
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transformer_engine.h>
...
@@ -57,16 +51,11 @@ using int32 = int32_t;
...
@@ -57,16 +51,11 @@ using int32 = int32_t;
using
int64
=
int64_t
;
using
int64
=
int64_t
;
using
fp32
=
float
;
using
fp32
=
float
;
using
fp16
=
half
;
using
fp16
=
half
;
#ifndef USE_ROCM
using
bf16
=
nv_bfloat16
;
using
bf16
=
nv_bfloat16
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
bf16
=
__hip_bfloat16
;
using
fp8e4m3
=
te_hip_fp8_e4m3
;
using
fp8e5m2
=
te_hip_fp8_e5m2
;
#endif //USE_ROCM
using
fp8e8m0
=
uint8_t
;
using
fp8e8m0
=
uint8_t
;
using
int8
=
int8_t
;
template
<
typename
T
>
template
<
typename
T
>
struct
TypeInfo
{
struct
TypeInfo
{
...
@@ -79,7 +68,8 @@ struct TypeInfo{
...
@@ -79,7 +68,8 @@ struct TypeInfo{
bf16
,
bf16
,
fp8e4m3
,
fp8e4m3
,
fp8e5m2
,
fp8e5m2
,
fp8e8m0
>
;
fp8e8m0
,
int8
>
;
template
<
typename
U
,
DType
current
>
template
<
typename
U
,
DType
current
>
struct
Helper
{
struct
Helper
{
...
@@ -326,11 +316,7 @@ struct Numeric_Traits<fp8e4m3> {
...
@@ -326,11 +316,7 @@ struct Numeric_Traits<fp8e4m3> {
static
constexpr
double
minSubnorm
=
1.0
/
static_cast
<
double
>
(
1
<<
9
);
// std::pow(2.0, -9.0);
static
constexpr
double
minSubnorm
=
1.0
/
static_cast
<
double
>
(
1
<<
9
);
// std::pow(2.0, -9.0);
static
constexpr
double
maxSubnorm
=
0.875
/
static_cast
<
double
>
(
1
<<
6
);
// std::pow(2.0, -6.0);
static
constexpr
double
maxSubnorm
=
0.875
/
static_cast
<
double
>
(
1
<<
6
);
// std::pow(2.0, -6.0);
static
constexpr
double
minNorm
=
1.0
/
static_cast
<
double
>
(
1
<<
6
);
// std::pow(2.0, -6.0);
static
constexpr
double
minNorm
=
1.0
/
static_cast
<
double
>
(
1
<<
6
);
// std::pow(2.0, -6.0);
#ifndef __HIP_PLATFORM_AMD__
static
constexpr
double
maxNorm
=
448.0
;
static
constexpr
double
maxNorm
=
448.0
;
#else
static
constexpr
double
maxNorm
=
240.0
;
#endif
static
constexpr
double
artifInf
=
10.0
*
maxNorm
;
// artificial Infinity
static
constexpr
double
artifInf
=
10.0
*
maxNorm
;
// artificial Infinity
static
constexpr
int
maxBiasedExponentAsFP32
=
8
+
FP32_EXPONENT_BIAS
;
static
constexpr
int
maxBiasedExponentAsFP32
=
8
+
FP32_EXPONENT_BIAS
;
static
constexpr
int
maxUnbiasedExponentAsFP32
=
8
;
static
constexpr
int
maxUnbiasedExponentAsFP32
=
8
;
...
...
tests/pytorch/distributed/test_fusible_ops.py
View file @
5753c5bb
...
@@ -687,7 +687,7 @@ def _test_fp8_scale_update(
...
@@ -687,7 +687,7 @@ def _test_fp8_scale_update(
"""Expected absmax and FP8 scale"""
"""Expected absmax and FP8 scale"""
amax
=
ref
.
abs
().
amax
()
amax
=
ref
.
abs
().
amax
()
max_val
=
{
max_val
=
{
"forward"
:
448.0
if
not
IS_HIP_EXTENSION
else
240.0
,
"forward"
:
448.0
,
"backward"
:
57344.0
,
"backward"
:
57344.0
,
}[
stage
]
}[
stage
]
scale
=
(
max_val
/
amax
)
/
(
2
**
margin
)
scale
=
(
max_val
/
amax
)
/
(
2
**
margin
)
...
...
tests/pytorch/test_recipe.py
View file @
5753c5bb
...
@@ -258,7 +258,7 @@ class TestFP8Recipe:
...
@@ -258,7 +258,7 @@ class TestFP8Recipe:
# Compute scale
# Compute scale
max_val
=
{
max_val
=
{
"forward"
:
448.0
if
not
IS_HIP_EXTENSION
else
240.0
,
"forward"
:
448.0
,
"backward"
:
57344.0
,
"backward"
:
57344.0
,
}[
stage
]
}[
stage
]
ref_scale
=
(
max_val
/
ref_amax
)
/
(
2
**
margin
)
ref_scale
=
(
max_val
/
ref_amax
)
/
(
2
**
margin
)
...
...
transformer_engine/common/CMakeLists.txt
View file @
5753c5bb
...
@@ -341,10 +341,6 @@ else()
...
@@ -341,10 +341,6 @@ else()
string_code_transpose_rtc_cast_transpose_cu
)
string_code_transpose_rtc_cast_transpose_cu
)
make_string_header_from_file
(
transpose/rtc/transpose.hip
make_string_header_from_file
(
transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu
)
string_code_transpose_rtc_transpose_cu
)
make_string_header_from_file
(
amd_detail/hip_float8.h
string_code_amd_detail_hip_float8_h
)
make_string_header_from_file
(
amd_detail/hip_f8_impl.h
string_code_amd_detail_hip_f8_impl_h
)
endif
()
endif
()
...
...
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
5753c5bb
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
#include <cassert>
#include <cassert>
#include <numeric>
#include <numeric>
#include "amd_detail/hip_float8.h"
#include "common/common.h"
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/cuda_runtime.h"
...
@@ -330,11 +329,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
...
@@ -330,11 +329,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert
(
rs_output
.
size
(
0
)
==
_ubuf
.
size
(
0
)
/
_tp_size
);
assert
(
rs_output
.
size
(
0
)
==
_ubuf
.
size
(
0
)
/
_tp_size
);
assert
(
rs_output
.
element_size
()
==
2
);
assert
(
rs_output
.
element_size
()
==
2
);
char
*
rs_output_ptr
=
reinterpret_cast
<
char
*>
(
rs_output
.
dptr
());
char
*
rs_output_ptr
=
reinterpret_cast
<
char
*>
(
rs_output
.
dptr
());
#ifdef USE_ROCM
reducescatter2_userbuff_fp8
<
te_hip_fp8_e5m2
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
#else
reducescatter2_userbuff_fp8
<
__nv_fp8_e5m2
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
reducescatter2_userbuff_fp8
<
__nv_fp8_e5m2
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
#endif
comm_elements
,
_ub_comm
,
_stream_comm
,
comm_elements
,
_ub_comm
,
_stream_comm
,
(
cudaEvent_t
)
_comm_launch_event
);
(
cudaEvent_t
)
_comm_launch_event
);
}
else
{
}
else
{
...
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
View file @
5753c5bb
...
@@ -2033,53 +2033,6 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
...
@@ -2033,53 +2033,6 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
}
}
}
}
#ifdef __HIP_PLATFORM_AMD__
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
te_hip_fp8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
te_hip_fp8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
template
<
typename
fp8type
>
void
reducescatter2_userbuff_fp8
(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
elements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
)
{
reducescatter2_userbuff_stridedoutput_fp8
<
fp8type
>
(
output
,
scale
,
handler
,
offset
,
elements
,
1
,
0
,
comm
,
stream
,
comm_launch_event
);
}
template
void
reducescatter2_userbuff_fp8
<
te_hip_fp8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
elements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
template
void
reducescatter2_userbuff_fp8
<
te_hip_fp8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
elements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
template
void
reducescatter2_userbuff_strided_atomic_fp8
<
te_hip_fp8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
template
void
reducescatter2_userbuff_strided_atomic_fp8
<
te_hip_fp8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
template
void
reducescatter2_userbuff_strided_multiatomic_fp8
<
te_hip_fp8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
template
void
reducescatter2_userbuff_strided_multiatomic_fp8
<
te_hip_fp8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
#else
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
__nv_fp8_e5m2
>(
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
__nv_fp8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
...
@@ -2125,7 +2078,6 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
...
@@ -2125,7 +2078,6 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
#endif
__global__
void
kuserbuffers_pullsend
(
int
myrank
,
int
peer
,
int
*
send_id
,
int
*
flagptr
)
{
__global__
void
kuserbuffers_pullsend
(
int
myrank
,
int
peer
,
int
*
send_id
,
int
*
flagptr
)
{
atomicAdd_system
(
flagptr
,
1
);
atomicAdd_system
(
flagptr
,
1
);
...
@@ -2844,21 +2796,12 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
...
@@ -2844,21 +2796,12 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
num_aligned_elements_per_input
,
tot_input_size
);
num_aligned_elements_per_input
,
tot_input_size
);
}
}
#ifdef __HIP_PLATFORM_AMD__
template
void
reduce_fp8_in_bf16_out
<
te_hip_fp8_e4m3
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
int
num_inputs
,
int
input_size
,
cudaStream_t
stream
);
template
void
reduce_fp8_in_bf16_out
<
te_hip_fp8_e5m2
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
int
num_inputs
,
int
input_size
,
cudaStream_t
stream
);
#else
template
void
reduce_fp8_in_bf16_out
<
__nv_fp8_e4m3
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
template
void
reduce_fp8_in_bf16_out
<
__nv_fp8_e4m3
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
int
num_inputs
,
int
input_size
,
int
num_inputs
,
int
input_size
,
cudaStream_t
stream
);
cudaStream_t
stream
);
template
void
reduce_fp8_in_bf16_out
<
__nv_fp8_e5m2
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
template
void
reduce_fp8_in_bf16_out
<
__nv_fp8_e5m2
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
int
num_inputs
,
int
input_size
,
int
num_inputs
,
int
input_size
,
cudaStream_t
stream
);
cudaStream_t
stream
);
#endif
template
<
int
nvec
>
template
<
int
nvec
>
__global__
void
__launch_bounds__
(
MAX_THREADS
/
4
)
__global__
void
__launch_bounds__
(
MAX_THREADS
/
4
)
...
...
transformer_engine/common/common.h
View file @
5753c5bb
...
@@ -25,11 +25,7 @@
...
@@ -25,11 +25,7 @@
#include <vector>
#include <vector>
#include "./nvtx.h"
#include "./nvtx.h"
#ifdef __HIP_PLATFORM_AMD__
#include "./util/hip_driver.h"
#else
#include "./util/cuda_driver.h"
#include "./util/cuda_driver.h"
#endif
#include "./util/logging.h"
#include "./util/logging.h"
namespace
transformer_engine
{
namespace
transformer_engine
{
...
@@ -257,15 +253,9 @@ using int64 = int64_t;
...
@@ -257,15 +253,9 @@ using int64 = int64_t;
using
fp32
=
float
;
using
fp32
=
float
;
using
fp16
=
half
;
using
fp16
=
half
;
using
int8
=
int8_t
;
using
int8
=
int8_t
;
#ifndef __HIP_PLATFORM_AMD__
using
bf16
=
nv_bfloat16
;
using
bf16
=
nv_bfloat16
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
bf16
=
__hip_bfloat16
;
using
fp8e4m3
=
te_hip_fp8_e4m3
;
using
fp8e5m2
=
te_hip_fp8_e5m2
;
#endif
#if CUDA_VERSION >= 12080
#if CUDA_VERSION >= 12080
using
fp8e8m0
=
__nv_fp8_e8m0
;
using
fp8e8m0
=
__nv_fp8_e8m0
;
#endif
#endif
...
@@ -287,15 +277,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
...
@@ -287,15 +277,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME
(
int64_t
)
TRANSFORMER_ENGINE_TYPE_NAME
(
int64_t
)
TRANSFORMER_ENGINE_TYPE_NAME
(
float
)
TRANSFORMER_ENGINE_TYPE_NAME
(
float
)
TRANSFORMER_ENGINE_TYPE_NAME
(
half
)
TRANSFORMER_ENGINE_TYPE_NAME
(
half
)
#ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_TYPE_NAME
(
__hip_bfloat16
)
TRANSFORMER_ENGINE_TYPE_NAME
(
te_hip_fp8_e4m3
)
TRANSFORMER_ENGINE_TYPE_NAME
(
te_hip_fp8_e5m2
)
#else
TRANSFORMER_ENGINE_TYPE_NAME
(
nv_bfloat16
)
TRANSFORMER_ENGINE_TYPE_NAME
(
nv_bfloat16
)
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e4m3
)
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e4m3
)
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e5m2
)
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e5m2
)
#endif
#if CUDA_VERSION >= 12080
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e8m0
)
TRANSFORMER_ENGINE_TYPE_NAME
(
__nv_fp8_e8m0
)
#endif
#endif
...
@@ -306,11 +290,7 @@ struct TypeExtrema;
...
@@ -306,11 +290,7 @@ struct TypeExtrema;
template
<
>
template
<
>
struct
TypeExtrema
<
fp8e4m3
>
{
struct
TypeExtrema
<
fp8e4m3
>
{
#ifndef __HIP_PLATFORM_AMD__
static
constexpr
float
max
=
448.0f
;
static
constexpr
float
max
=
448.0f
;
#else
static
constexpr
float
max
=
240.0f
;
#endif
};
};
template
<
>
template
<
>
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
5753c5bb
...
@@ -46,17 +46,10 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
...
@@ -46,17 +46,10 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
return
HIP_R_32F
;
return
HIP_R_32F
;
case
DType
::
kBFloat16
:
case
DType
::
kBFloat16
:
return
HIP_R_16BF
;
return
HIP_R_16BF
;
#if HIP_VERSION >= 60300000
case
DType
::
kFloat8E4M3
:
case
DType
::
kFloat8E4M3
:
return
te_fp8_fnuz
()
?
HIP_R_8F_E4M3_FNUZ
:
HIP_R_8F_E4M3
;
return
HIP_R_8F_E4M3
;
case
DType
::
kFloat8E5M2
:
case
DType
::
kFloat8E5M2
:
return
te_fp8_fnuz
()
?
HIP_R_8F_E5M2_FNUZ
:
HIP_R_8F_E5M2
;
return
HIP_R_8F_E5M2
;
#else
case
DType
::
kFloat8E4M3
:
return
HIP_R_8F_E4M3_FNUZ
;
case
DType
::
kFloat8E5M2
:
return
HIP_R_8F_E5M2_FNUZ
;
#endif
default:
default:
NVTE_ERROR
(
"Invalid type"
);
NVTE_ERROR
(
"Invalid type"
);
}
}
...
@@ -863,10 +856,7 @@ protected:
...
@@ -863,10 +856,7 @@ protected:
}
}
#if HIP_VERSION >= 60300000
#if HIP_VERSION >= 60300000
auto
fp8_filter
=
te_fp8_fnuz
()
auto
fp8_filter
=
[](
const
hipDataType
&
val
)
{
?
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3
&&
val
!=
HIP_R_8F_E5M2
);
}
:
[](
const
hipDataType
&
val
)
{
return
(
val
!=
HIP_R_8F_E4M3_FNUZ
&&
val
!=
HIP_R_8F_E5M2_FNUZ
);
return
(
val
!=
HIP_R_8F_E4M3_FNUZ
&&
val
!=
HIP_R_8F_E5M2_FNUZ
);
};
};
#else
#else
...
...
transformer_engine/common/multi_tensor/adam.cu
View file @
5753c5bb
...
@@ -25,13 +25,9 @@ typedef enum {
...
@@ -25,13 +25,9 @@ typedef enum {
}
adamMode_t
;
}
adamMode_t
;
using
MATH_T
=
float
;
using
MATH_T
=
float
;
#ifndef __HIP_PLATFORM_AMD__
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
transformer_engine
::
DType
;
using
fp8e4m3
=
te_hip_fp8_e4m3
;
using
fp8e5m2
=
te_hip_fp8_e5m2
;
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
is_fp8
:
std
::
false_type
{};
struct
is_fp8
:
std
::
false_type
{};
...
...
transformer_engine/common/normalization/common.h
View file @
5753c5bb
...
@@ -329,15 +329,9 @@ using int32 = int32_t;
...
@@ -329,15 +329,9 @@ using int32 = int32_t;
using
fp32
=
float
;
using
fp32
=
float
;
using
fp16
=
half
;
using
fp16
=
half
;
using
int8
=
int8_t
;
using
int8
=
int8_t
;
#ifndef __HIP_PLATFORM_AMD__
using
bf16
=
nv_bfloat16
;
using
bf16
=
nv_bfloat16
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
bf16
=
__hip_bfloat16
;
using
fp8e4m3
=
te_hip_fp8_e4m3
;
using
fp8e5m2
=
te_hip_fp8_e5m2
;
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
TypeToDType
;
struct
TypeToDType
;
...
...
transformer_engine/common/permutation/permutation.cu
View file @
5753c5bb
...
@@ -11,8 +11,6 @@
...
@@ -11,8 +11,6 @@
#include "../common.h"
#include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
using
__nv_fp8_e4m3
=
te_hip_fp8_e4m3
;
using
__nv_fp8_e5m2
=
te_hip_fp8_e5m2
;
#define __ldlu(x) __ldg(x)
#define __ldlu(x) __ldg(x)
#endif
#endif
...
...
transformer_engine/common/recipe/__init__.py
View file @
5753c5bb
...
@@ -35,7 +35,7 @@ class Format(Enum):
...
@@ -35,7 +35,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format
FP8 tensors in the backward pass are in e5m2 format
"""
"""
E4M3
=
_FormatHelper
(
max_fwd
=
448
if
not
IS_HIP_EXTENSION
else
240.0
,
max_bwd
=
448
if
not
IS_HIP_EXTENSION
else
240.0
)
E4M3
=
_FormatHelper
(
max_fwd
=
448
,
max_bwd
=
448
)
E5M2
=
_FormatHelper
(
max_fwd
=
57344
,
max_bwd
=
57344
)
E5M2
=
_FormatHelper
(
max_fwd
=
57344
,
max_bwd
=
57344
)
HYBRID
=
_FormatHelper
(
max_fwd
=
E4M3
.
max_fwd
,
max_bwd
=
E5M2
.
max_bwd
)
HYBRID
=
_FormatHelper
(
max_fwd
=
E4M3
.
max_fwd
,
max_bwd
=
E5M2
.
max_bwd
)
...
...
transformer_engine/common/recipe/delayed_scaling.cu
View file @
5753c5bb
...
@@ -32,11 +32,7 @@ const char* dtype_name(DType dtype) {
...
@@ -32,11 +32,7 @@ const char* dtype_name(DType dtype) {
inline
float
fp8_dtype_max
(
DType
dtype
)
{
inline
float
fp8_dtype_max
(
DType
dtype
)
{
switch
(
dtype
)
{
switch
(
dtype
)
{
case
DType
::
kFloat8E4M3
:
case
DType
::
kFloat8E4M3
:
#ifndef __HIP_PLATFORM_AMD__
return
448
;
return
448
;
#else
return
240
;
#endif
case
DType
::
kFloat8E5M2
:
case
DType
::
kFloat8E5M2
:
return
57344
;
return
57344
;
default:
default:
...
...
transformer_engine/common/util/rtc.cpp
View file @
5753c5bb
...
@@ -29,11 +29,6 @@ namespace {
...
@@ -29,11 +29,6 @@ namespace {
#include "string_code_util_math_h.h"
#include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h"
#include "string_code_utils_cuh.h"
#ifdef USE_ROCM
#include "string_code_amd_detail_hip_float8_h.h"
#include "string_code_amd_detail_hip_f8_impl_h.h"
#endif // USE_ROCM
#ifndef USE_ROCM
#ifndef USE_ROCM
/*! \brief Latest compute capability that NVRTC supports
/*! \brief Latest compute capability that NVRTC supports
*
*
...
@@ -187,9 +182,9 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
...
@@ -187,9 +182,9 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
// Compile source
// Compile source
nvrtcProgram
program
;
nvrtcProgram
program
;
#ifdef USE_ROCM
#ifdef USE_ROCM
constexpr
int
num_headers
=
4
;
constexpr
int
num_headers
=
2
;
const
char
*
headers
[
num_headers
]
=
{
string_code_utils_cuh
,
string_code_util_math_h
,
string_code_amd_detail_hip_float8_h
,
string_code_amd_detail_hip_f8_impl_h
};
const
char
*
headers
[
num_headers
]
=
{
string_code_utils_cuh
,
string_code_util_math_h
};
const
char
*
include_names
[
num_headers
]
=
{
"utils_hip.cuh"
,
"util/math.h"
,
"amd_detail/hip_float8.h"
,
"amd_detail/hip_f8_impl.h"
};
const
char
*
include_names
[
num_headers
]
=
{
"utils_hip.cuh"
,
"util/math.h"
};
#else
#else
constexpr
int
num_headers
=
2
;
constexpr
int
num_headers
=
2
;
constexpr
const
char
*
headers
[
num_headers
]
=
{
string_code_utils_cuh
,
string_code_util_math_h
};
constexpr
const
char
*
headers
[
num_headers
]
=
{
string_code_utils_cuh
,
string_code_util_math_h
};
...
...
transformer_engine/common/utils.cuh
View file @
5753c5bb
...
@@ -982,13 +982,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
...
@@ -982,13 +982,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifndef __HIP_PLATFORM_AMD__
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
using
fp8e4m3
=
te_hip_fp8_e4m3
;
using
fp8e5m2
=
te_hip_fp8_e5m2
;
#endif
using
e8m0_t
=
uint8_t
;
using
e8m0_t
=
uint8_t
;
using
int8
=
int8_t
;
using
int8
=
int8_t
;
...
@@ -1003,11 +998,7 @@ struct Numeric_Traits;
...
@@ -1003,11 +998,7 @@ struct Numeric_Traits;
template
<
>
template
<
>
struct
Numeric_Traits
<
fp8e4m3
>
{
struct
Numeric_Traits
<
fp8e4m3
>
{
static
constexpr
int
maxUnbiasedExponent
=
8
;
static
constexpr
int
maxUnbiasedExponent
=
8
;
#ifndef __HIP_PLATFORM_AMD__
static
constexpr
double
maxNorm
=
448
;
static
constexpr
double
maxNorm
=
448
;
#else
static
constexpr
double
maxNorm
=
240
;
#endif
};
};
template
<
>
template
<
>
...
...
transformer_engine/pytorch/optimizers/fused_adam.py
View file @
5753c5bb
...
@@ -197,7 +197,7 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -197,7 +197,7 @@ class FusedAdam(torch.optim.Optimizer):
torch
.
float16
:
torch
.
full
(
torch
.
float16
:
torch
.
full
(
[
1
],
torch
.
finfo
(
torch
.
float16
).
max
/
2.0
,
dtype
=
torch
.
float32
[
1
],
torch
.
finfo
(
torch
.
float16
).
max
/
2.0
,
dtype
=
torch
.
float32
),
),
torch
.
uint8
:
torch
.
full
([
1
],
448.0
if
not
IS_HIP_EXTENSION
else
240.0
,
dtype
=
torch
.
float32
),
torch
.
uint8
:
torch
.
full
([
1
],
448.0
,
dtype
=
torch
.
float32
),
}
}
self
.
_scales
=
{}
self
.
_scales
=
{}
self
.
use_decoupled_grad
=
use_decoupled_grad
self
.
use_decoupled_grad
=
use_decoupled_grad
...
...
transformer_engine/pytorch/tensor/utils.py
View file @
5753c5bb
...
@@ -281,7 +281,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
...
@@ -281,7 +281,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
# Step 3: Update scales and scale_invs.
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------------
if
fp8_dtype
==
tex
.
DType
.
kFloat8E4M3
:
if
fp8_dtype
==
tex
.
DType
.
kFloat8E4M3
:
max_fp8
=
448.0
if
not
IS_HIP_EXTENSION
else
240.0
max_fp8
=
448.0
elif
fp8_dtype
==
tex
.
DType
.
kFloat8E5M2
:
elif
fp8_dtype
==
tex
.
DType
.
kFloat8E5M2
:
max_fp8
=
57344.0
max_fp8
=
57344.0
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment