OpenDAS / TransformerEngine · Commits

Commit fbee8990
Authored Apr 01, 2025 by yuguo
[DCU] fix fp8
Parent: 57deee08

Showing 17 changed files with 67 additions and 26 deletions (+67 −26)
tests/cpp/CMakeLists.txt                                    +1  −0
tests/cpp/operator/test_cast_current_scaling.cu             +3  −3
tests/cpp/operator/test_cast_mxfp8.cu                       +1  −1
tests/cpp/operator/test_cast_transpose_current_scaling.cu   +3  −3
tests/cpp/operator/test_cublaslt_gemm.cu                    +11 −11
tests/cpp/operator/test_dequantize_mxfp8.cu                 +1  −1
tests/cpp/operator/test_normalization.cu                    +7  −0
tests/cpp/operator/test_normalization_mxfp8.cu              +4  −0
tests/cpp/test_common.cu                                    +14 −0
tests/cpp/test_common.h                                     +5  −0
tests/pytorch/test_multi_tensor.py                          +2  −2
transformer_engine/common/CMakeLists.txt                    +1  −0
transformer_engine/common/recipe/__init__.py                +2  −2
transformer_engine/common/recipe/delayed_scaling.cu         +4  −0
transformer_engine/common/utils.cuh                         +4  −0
transformer_engine/pytorch/optimizers/fused_adam.py         +2  −2
transformer_engine/pytorch/tensor/utils.py                  +2  −1
tests/cpp/CMakeLists.txt

@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
+# CXX=hipcc make build && cd build && cmake ../
 cmake_minimum_required(VERSION 3.18)
 option(USE_CUDA "Use CUDA" ON)
tests/cpp/operator/test_cast_current_scaling.cu

@@ -58,7 +58,7 @@ void compute_amax_scale_ref(const InputType *data,
   float scale = 1.f;
   float scale_inv = 1.f;
-  if (isinf(clamp_amax) || clamp_amax == 0.f) {
+  if (std::isinf(clamp_amax) || clamp_amax == 0.f) {
     *scale_ptr = scale;
     *scale_inv_ptr = scale_inv;
     return;
@@ -69,11 +69,11 @@ void compute_amax_scale_ref(const InputType *data,
   // The amax is too small that the scale becoming infinite in FP32. In other word,
   // the scale is not representable in FP32.
-  if (isinf(scale)) {
+  if (std::isinf(scale)) {
     scale = std::numeric_limits<float>::max();
   }
-  if (isnan(scale)) {
+  if (std::isnan(scale)) {
     scale = 1.f;
   }
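The only change in this file (repeated in test_cast_mxfp8.cu and test_cast_transpose_current_scaling.cu below) is qualifying isinf/isnan with std::, presumably because the unqualified calls are ambiguous or resolve to non-standard overloads when these host-side reference helpers are compiled with hipcc. For context, a standalone sketch of the guarded amax-to-scale computation these hunks live in; the function name and the max_fp8 / amax formula are assumptions, only the guard structure is taken from the diff:

#include <cmath>
#include <limits>

// Hypothetical standalone version of the reference scale computation.
// A degenerate amax (inf or zero) keeps identity scaling; an overflowing
// or NaN scale is clamped exactly as in the hunks above.
void compute_scale_ref(float amax, float max_fp8,
                       float *scale_ptr, float *scale_inv_ptr) {
  float scale = 1.f;
  float scale_inv = 1.f;
  if (std::isinf(amax) || amax == 0.f) {
    *scale_ptr = scale;
    *scale_inv_ptr = scale_inv;
    return;
  }
  scale = max_fp8 / amax;  // assumed current-scaling formula
  if (std::isinf(scale)) {
    // amax too small: the scale is not representable in FP32.
    scale = std::numeric_limits<float>::max();
  }
  if (std::isnan(scale)) {
    scale = 1.f;
  }
  *scale_ptr = scale;
  *scale_inv_ptr = 1.f / scale;
}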
tests/cpp/operator/test_cast_mxfp8.cu

@@ -69,7 +69,7 @@ void scale_block(const ProcessingMethod processing_method,
         elt *= static_cast<float>(grad[idx]);
       }
       dbias[j] += elt;
-      if (isinf(elt) || isnan(elt)) {
+      if (std::isinf(elt) || std::isnan(elt)) {
        continue;
      }
      amax = std::max(amax, std::abs(elt));
tests/cpp/operator/test_cast_transpose_current_scaling.cu

@@ -62,7 +62,7 @@ void compute_amax_scale_ref(const InputType *data,
   float scale = 1.f;
   float scale_inv = 1.f;
-  if (isinf(clamp_amax) || clamp_amax == 0.f) {
+  if (std::isinf(clamp_amax) || clamp_amax == 0.f) {
     *scale_ptr = scale;
     *scale_inv_ptr = scale_inv;
     return;
@@ -73,11 +73,11 @@ void compute_amax_scale_ref(const InputType *data,
   // The amax is too small that the scale becoming infinite in FP32. In other word,
   // the scale is not representable in FP32.
-  if (isinf(scale)) {
+  if (std::isinf(scale)) {
     scale = std::numeric_limits<float>::max();
   }
-  if (isnan(scale)) {
+  if (std::isnan(scale)) {
     scale = 1.f;
   }
tests/cpp/operator/test_cublaslt_gemm.cu

@@ -111,16 +111,16 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   DType dtype = TypeInfo<D_Type>::dtype;
   // pytorch tensor storage is row-major while cublas/rocblas is column-major
-  Tensor A({k, m}, atype);
-  Tensor B({n, k}, btype);
-  Tensor D({n, m}, dtype);
+  Tensor A("A", {k, m}, atype);
+  Tensor B("B", {n, k}, btype);
+  Tensor D("D", {n, m}, dtype);
   Tensor bias;
   if (use_bias){
-    bias = Tensor({m}, bias_type);
+    bias = Tensor("bias", {m}, bias_type);
   }
   Tensor pre_gelu_out;
   if (use_gelu){
-    pre_gelu_out = Tensor({n, m}, gelu_type);
+    pre_gelu_out = Tensor("pre_gelu_out", {n, m}, gelu_type);
   }
   //initialize the data and scale inv of A, B
@@ -149,7 +149,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   }
 #endif
-  Tensor Workspace({33554432}, DType::kByte);
+  Tensor Workspace("Workspace", {33554432}, DType::kByte);
   //perform the gemm in GPU
   nvte_cublas_gemm(A.data(),
@@ -180,11 +180,11 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   }
   float ref_amax_d;
   compute_ref<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
-    A.cpu_dptr<A_Type>(),
-    B.cpu_dptr<B_Type>(),
-    A.scale_inv(),
-    B.scale_inv(),
-    use_bias ? bias.cpu_dptr<Bias_Type>() : nullptr,
+    A.rowwise_cpu_dptr<A_Type>(),
+    B.rowwise_cpu_dptr<B_Type>(),
+    A.rowwise_scale_inv(),
+    B.rowwise_scale_inv(),
+    use_bias ? bias.rowwise_cpu_dptr<Bias_Type>() : nullptr,
     D.scale(),
     m, k, n,
     ref_D.get(),
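The hunks in this file track an updated test Tensor API rather than an FP8 numeric fix: tensors now take a name as the first constructor argument, and host pointers and scales come from rowwise_-prefixed accessors. A hypothetical mock of the call shape (the real declarations live in tests/cpp/test_common.h; every name and member below is illustrative, not the actual implementation):

#include <cstdio>
#include <string>
#include <vector>

// Mock illustrating the constructor/accessor shape the updated tests rely on.
struct MockTensor {
  std::string name;  // lets comparison failures report which tensor mismatched
  std::vector<size_t> shape;
  MockTensor(const std::string &n, std::vector<size_t> s)
      : name(n), shape(std::move(s)) {}
  float rowwise_scale_inv() const { return 1.f; }  // placeholder value
};

int main() {
  const size_t k = 64, m = 128;
  MockTensor A("A", {k, m});  // mirrors Tensor A("A", {k, m}, atype);
  std::printf("%s: scale_inv = %g\n", A.name.c_str(), A.rowwise_scale_inv());
  return 0;
}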
tests/cpp/operator/test_dequantize_mxfp8.cu

@@ -143,7 +143,7 @@ void generate_data(InputType * const data,
       if (is_negative) {
         val = -val;
       }
-      data[idx] = static_cast<InputType>(val);
+      data[idx] = static_cast<InputType>(static_cast<float>(val));
     }
   }
 }
tests/cpp/operator/test_normalization.cu

@@ -12,6 +12,7 @@
 #include <random>
 #include <cuda_bf16.h>
+#include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <gtest/gtest.h>
@@ -78,11 +79,17 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
   } else {
     if (use_cudnn){
       compute_t g = static_cast<compute_t>(0.f);
+#ifndef __HIP_PLATFORM_AMD__
       InputType gi = gamma;
       if (zero_centered_gamma) {
         gi = gi + static_cast<InputType>(1.f);
       }
       g = static_cast<compute_t>(gi);
+#else
+      if (zero_centered_gamma) {
+        g += static_cast<compute_t>(1.f);
+      }
+#endif
       return g;
     } else {
       compute_t g = static_cast<compute_t>(gamma);
tests/cpp/operator/test_normalization_mxfp8.cu

@@ -133,7 +133,11 @@ void compute_ref_stats(NormType norm_type,
         compute_t current = static_cast<compute_t>(data[i * H + j]);
         sum_sq += (current - m) * (current - m);
       }
+#ifdef __HIP_PLATFORM_AMD__
+      rsigma[i] = 1.0 / sqrtf((sum_sq / H) + epsilon);
+#else
       rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
+#endif
     }
   }
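The rsigma branch presumably works around rsqrtf() not being available (or not host-callable) under __HIP_PLATFORM_AMD__; the AMD path spells out the identical quantity as 1.0 / sqrtf(...). A host-only sketch of the identity, written with std::sqrt so it compiles anywhere:

#include <cmath>

// rsqrtf(x) computes 1/sqrt(x); the AMD branch above writes that out explicitly.
inline float rsigma_ref(float sum_sq, float H, float epsilon) {
  return 1.0f / std::sqrt((sum_sq / H) + epsilon);
}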
tests/cpp/test_common.cu

@@ -584,8 +584,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const
   const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol);
   if (i != N) {
+#ifndef __HIP_PLATFORM_AMD__
     const double t = static_cast<double>(test_data[i]);
     const double r = static_cast<double>(ref_data[i]);
+#else
+    const double t = static_cast<double>(static_cast<float>(test_data[i]));
+    const double r = static_cast<double>(static_cast<float>(ref_data[i]));
+#endif
     std::string direction = rowwise ? "rowwise" : "columnwise";
     ASSERT_FALSE(true) << "Error in tensor " << name << " in "
                        << direction << " direction." << std::endl
@@ -607,8 +612,13 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref
 void compareResults(const std::string &name, const float test, const float ref,
                     double atol, double rtol) {
+#ifndef __HIP_PLATFORM_AMD__
   double t = static_cast<double>(test);
   double r = static_cast<double>(ref);
+#else
+  double t = static_cast<double>(static_cast<float>(test));
+  double r = static_cast<double>(static_cast<float>(ref));
+#endif
   bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
   ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                          << "Mismatch: " << t << " vs " << r;
@@ -692,7 +702,11 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
     std::uniform_real_distribution<> dis(-2.0, 1.0);
     for (int i = idx_min; i < idx_max; ++i) {
+#ifndef __HIP_PLATFORM_AMD__
       data[i] = static_cast<T>(dis(gen_local));
+#else
+      data[i] = static_cast<T>(static_cast<float>(dis(gen_local)));
+#endif
     }
   }
   gen->discard(size);
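All three hunks in this file apply one pattern: on the HIP platform, values are widened to float before the final cast. The low-precision host types used in the tests presumably expose a conversion to float but not directly to double, so the one-step static_cast fails to compile there. A generic sketch of the pattern; to_double is a hypothetical helper, not part of the test suite:

// Widen through float on HIP, where the low-precision host types only
// convert to float; cast directly everywhere else.
template <typename T>
double to_double(const T &x) {
#ifndef __HIP_PLATFORM_AMD__
  return static_cast<double>(x);
#else
  return static_cast<double>(static_cast<float>(x));
#endif
}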
tests/cpp/test_common.h

@@ -61,6 +61,7 @@ using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 #else
+using bf16 = __hip_bfloat16;
 using fp8e4m3 = te_hip_fp8_e4m3;
 using fp8e5m2 = te_hip_fp8_e5m2;
 #endif //USE_ROCM
@@ -325,7 +326,11 @@ struct Numeric_Traits<fp8e4m3> {
   static constexpr double minSubnorm = 1.0 / static_cast<double>(1 << 9);    // std::pow(2.0, -9.0);
   static constexpr double maxSubnorm = 0.875 / static_cast<double>(1 << 6);  // std::pow(2.0, -6.0);
   static constexpr double minNorm    = 1.0 / static_cast<double>(1 << 6);    // std::pow(2.0, -6.0);
+#ifndef __HIP_PLATFORM_AMD__
   static constexpr double maxNorm = 448.0;
+#else
+  static constexpr double maxNorm = 240.0;
+#endif
   static constexpr double artifInf = 10.0 * maxNorm;  // artificial Infinity
   static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS;
   static constexpr int maxUnbiasedExponentAsFP32 = 8;
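The 448-versus-240 split introduced here recurs through the rest of the commit (utils.cuh, delayed_scaling.cu, recipe/__init__.py, fused_adam.py, tensor/utils.py, test_multi_tensor.py). It matches the difference between the two e4m3 encodings: OCP FP8 e4m3, used on NVIDIA hardware, has a maximum finite value of 448, while the FNUZ e4m3 variant used on AMD GPUs (presumably the DCU target of this commit) tops out at 240 because it reclaims the infinity encodings and keeps a single NaN. The e5m2 maximum of 57344 is the same in both families, which is why those constants are untouched. A compile-time sketch of the selection pattern:

#include <cstdio>

// OCP e4m3:  S.1111.110 -> 1.75  * 2^8 = 448
// e4m3fnuz:  S.1111.111 -> 1.875 * 2^7 = 240 (no inf; NaN is only 1.0000.000)
constexpr double fp8_e4m3_max =
#ifndef __HIP_PLATFORM_AMD__
    448.0;
#else
    240.0;
#endif

int main() {
  std::printf("e4m3 max finite value: %g\n", fp8_e4m3_max);
  return 0;
}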
tests/pytorch/test_multi_tensor.py

@@ -10,7 +10,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply
 from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 input_size_pairs = [
     (7777 * 77, 555 * 555),
@@ -224,7 +224,7 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
 @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
 @pytest.mark.parametrize("applier", appliers)
 @pytest.mark.parametrize("repeat", [1, 55])
-@pytest.mark.parametrize("max_fp8", [448.0, 57344.0])
+@pytest.mark.parametrize("max_fp8", [448.0 if not IS_HIP_EXTENSION else 240.0, 57344.0])
 @pytest.mark.parametrize("pow_2_scales", [False, True])
 @pytest.mark.parametrize("epsilon", [0.0, 100.0])
 def test_multi_tensor_compute_scale_and_scale_inv(
transformer_engine/common/CMakeLists.txt

@@ -165,6 +165,7 @@ else()
     activation/relu.cu
     activation/swiglu.cu
     gemm/cublaslt_gemm.cu
+    gemm/hipblas_gemm.cu
     normalization/common.cpp
     normalization/layernorm/ln_api.cpp
     normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
transformer_engine/common/recipe/__init__.py

@@ -8,7 +8,7 @@ import warnings
 from enum import Enum
 from typing import Literal, Optional, Union, Callable, NamedTuple
 from pydantic.dataclasses import dataclass
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 class _FormatHelper(NamedTuple):
     """
@@ -34,7 +34,7 @@ class Format(Enum):
     FP8 tensors in the backward pass are in e5m2 format
     """
-    E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
+    E4M3 = _FormatHelper(max_fwd=448 if not IS_HIP_EXTENSION else 240.0, max_bwd=448 if not IS_HIP_EXTENSION else 240.0)
     E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
     HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
transformer_engine/common/recipe/delayed_scaling.cu

@@ -36,7 +36,11 @@ const char* dtype_name(DType dtype) {
 inline float fp8_dtype_max(DType dtype) {
   switch (dtype) {
     case DType::kFloat8E4M3:
+#ifndef __HIP_PLATFORM_AMD__
       return 448;
+#else
+      return 240;
+#endif
     case DType::kFloat8E5M2:
       return 57344;
     default:
transformer_engine/common/utils.cuh

@@ -1002,7 +1002,11 @@ struct Numeric_Traits;
 template <>
 struct Numeric_Traits<fp8e4m3> {
   static constexpr int maxUnbiasedExponent = 8;
+#ifndef __HIP_PLATFORM_AMD__
   static constexpr double maxNorm = 448;
+#else
+  static constexpr double maxNorm = 240;
+#endif
 };
 template <>
transformer_engine/pytorch/optimizers/fused_adam.py

@@ -14,7 +14,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 from .multi_tensor_apply import multi_tensor_applier
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 def get_fp8_meta(fp8_tensor):
     """FP8 metadata getter."""
@@ -197,7 +197,7 @@ class FusedAdam(torch.optim.Optimizer):
             torch.float16: torch.full(
                 [1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32
             ),
-            torch.uint8: torch.full([1], 448.0, dtype=torch.float32),
+            torch.uint8: torch.full([1], 448.0 if not IS_HIP_EXTENSION else 240.0, dtype=torch.float32),
         }
         self._scales = {}
         self.use_decoupled_grad = use_decoupled_grad
transformer_engine/pytorch/tensor/utils.py

@@ -5,6 +5,7 @@
 """Helper functions for using fp8 tensors as weights"""
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
@@ -243,7 +244,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
     # Step 3: Update scales and scale_invs.
     # ---------------------------------------------------------------------------------------------
     if fp8_dtype == tex.DType.kFloat8E4M3:
-        max_fp8 = 448.0
+        max_fp8 = 448.0 if not IS_HIP_EXTENSION else 240.0
     elif fp8_dtype == tex.DType.kFloat8E5M2:
         max_fp8 = 57344.0
     else: