OpenDAS / TransformerEngine / Commits

Commit f9d870f4, authored May 23, 2025 by yuguo

Merge branch 'develop_v2.3' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine

Parents: 7405fe09, 80c5079c

Showing 12 changed files with 67 additions and 14 deletions (+67 -14)
tests/pytorch/references/blockwise_quantizer_reference.py                   +9  -2
tests/pytorch/references/quantize_scale_calc.py                             +4  -1
tests/pytorch/test_float8_blockwise_scaling_exact.py                        +3  -3
transformer_engine/common/common.h                                          +27 -1
transformer_engine/common/include/transformer_engine/transformer_engine.h   +1  -0
transformer_engine/common/normalization/common.h                            +5  -0
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu  +1  -1
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu  +5  -5
transformer_engine/common/util/pybind_helper.h                              +2  -1
transformer_engine/common/utils.cuh                                         +7  -0
transformer_engine/pytorch/constants.py                                     +1  -0
transformer_engine/pytorch/fp8.py                                           +2  -0
tests/pytorch/references/blockwise_quantizer_reference.py

@@ -116,7 +116,10 @@ class BlockwiseQuantizerReference:
             .reshape(M // tile_len, K // tile_len, tile_len**2)
             .amax(dim=-1)
         ).float()
-        dtype_max = torch.finfo(quant_dtype).max
+        if quant_dtype == torch.int8:
+            dtype_max = torch.iinfo(quant_dtype).max
+        else:
+            dtype_max = torch.finfo(quant_dtype).max
         scale, scale_inv, _ = scale_from_amax_tensor(
             x_dtype=x.dtype,

@@ -152,7 +155,10 @@ class BlockwiseQuantizerReference:
         eps: float,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         M, K = x.shape
-        dtype_max = torch.finfo(quant_dtype).max
+        if quant_dtype == torch.int8:
+            dtype_max = torch.iinfo(quant_dtype).max
+        else:
+            dtype_max = torch.finfo(quant_dtype).max
         x_tiled = x.reshape(M, K // tile_len, tile_len)
         amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
         scale, scale_inv, _ = scale_from_amax_tensor(

@@ -272,6 +278,7 @@ class BlockwiseQuantizerReference:
         assert quant_dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
+            torch.int8,
         ), "Unsupported quant dtype."
         assert quant_tile_shape in ((1, 128), (128, 128))
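The change above swaps torch.finfo for torch.iinfo when the quantization dtype is an integer type, since torch.finfo raises on integer dtypes. A minimal standalone sketch of that selection (plain PyTorch, not the repo's helper):

import torch

def dtype_max(quant_dtype: torch.dtype) -> float:
    # torch.finfo only accepts floating-point dtypes; integer dtypes
    # such as torch.int8 need torch.iinfo instead.
    if quant_dtype.is_floating_point:
        return float(torch.finfo(quant_dtype).max)
    return float(torch.iinfo(quant_dtype).max)

assert dtype_max(torch.int8) == 127.0
assert dtype_max(torch.float8_e4m3fn) == 448.0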
tests/pytorch/references/quantize_scale_calc.py

@@ -24,7 +24,10 @@ def scale_from_amax_tensor(
     - amax: Amax tensor with updates made for extrema values.
     """
     assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
+    if quant_dtype == torch.int8:
+        fp8_max = torch.iinfo(quant_dtype).max
+    else:
+        fp8_max = torch.finfo(quant_dtype).max
     # Clamping amax to avoid division by small numbers
     amax = torch.max(amax, torch.tensor(eps))
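For context, scale_from_amax_tensor turns per-tile amax values into quantization scales. A simplified sketch of the arithmetic implied by the diff (the real helper also supports pow-2 scales and updates amax for extrema values; this toy version is an assumption, not the repo's code):

import torch

def simple_scale_from_amax(amax: torch.Tensor, quant_dtype: torch.dtype, eps: float):
    # Pick the representable maximum of the target dtype, as in the diff.
    if quant_dtype == torch.int8:
        qmax = torch.iinfo(quant_dtype).max
    else:
        qmax = torch.finfo(quant_dtype).max
    # Clamp amax to avoid division by small numbers (comment from the diff).
    amax = torch.max(amax, torch.tensor(eps))
    scale = qmax / amax        # multiply x by scale to quantize
    scale_inv = 1.0 / scale    # multiply back to dequantize
    return scale, scale_inv, amax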
tests/pytorch/test_float8_blockwise_scaling_exact.py

@@ -208,7 +208,7 @@ def check_quantization_block_tiling_versus_reference(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]

@@ -243,7 +243,7 @@ def test_quantization_block_tiling_versus_reference(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]

@@ -274,7 +274,7 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
 @pytest.mark.parametrize("tile_size", [(128, 128)])
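Stacked parametrize decorators multiply, so each test above now runs its full cross product including the new torch.int8 case. A self-contained illustration of the pattern (hypothetical toy test, not from the repo):

import pytest
import torch

@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
    "quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str
)
def test_dtype_matrix(x_dtype, quant_dtype):
    # 3 quant dtypes x 2 input dtypes = 6 collected cases, with ids like
    # test_dtype_matrix[torch.int8-torch.float32].
    assert x_dtype.is_floating_point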
transformer_engine/common/common.h

@@ -256,6 +256,7 @@ using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
 using fp16 = half;
+using int8 = int8_t;
 #ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;

@@ -269,6 +270,7 @@ using fp8e5m2 = te_hip_fp8_e5m2;
 using fp8e8m0 = __nv_fp8_e8m0;
 #endif
 using e8m0_t = uint8_t;
+using int8 = int8_t;
 namespace detail {

@@ -311,6 +313,11 @@ struct TypeExtrema<fp8e4m3> {
 #endif
 };
+template <>
+struct TypeExtrema<int8> {
+  static constexpr float max = 127.0f;
+};
+
 template <>
 struct TypeExtrema<fp8e5m2> {
   static constexpr float max = 57344.0f;

@@ -337,7 +344,7 @@ struct TypeExtrema {
 template <typename T>
 struct TypeInfo {
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
   template <typename U, DType current>
   struct Helper {

@@ -502,6 +509,25 @@ struct TypeInfo {
       NVTE_ERROR("Invalid type."); \
   }

+#define TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(dtype, type, ...) \
+  switch (dtype) { \
+    using namespace transformer_engine; \
+    case DType::kFloat8E5M2: { \
+      using type = fp8e5m2; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kFloat8E4M3: { \
+      using type = fp8e4m3; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kInt8: { \
+      using type = int8; \
+      { __VA_ARGS__ } \
+    } break; \
+    default: \
+      NVTE_ERROR("Invalid type."); \
+  }
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
   switch (dtype) { \
     using namespace transformer_engine; \
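The new TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT macro extends the former FP8-only output dispatch with a kInt8 case; the two .cu files below switch to it so int8 outputs reach the same kernels. For illustration only, a Python analogue of this enum-to-type dispatch (toy code, not a TE API):

from enum import Enum
import torch

class DType(Enum):
    # Mirrors the C++ enum values used in this commit.
    kFloat8E4M3 = 7
    kFloat8E5M2 = 8
    kInt8 = 10

# Map each 8-bit enum member to the concrete element type a kernel would use.
_8BIT_TYPES = {
    DType.kFloat8E5M2: torch.float8_e5m2,
    DType.kFloat8E4M3: torch.float8_e4m3fn,
    DType.kInt8: torch.int8,
}

def switch_8bit(dtype: DType) -> torch.dtype:
    try:
        return _8BIT_TYPES[dtype]
    except KeyError:
        raise ValueError("Invalid type.")  # analogous to NVTE_ERROR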
transformer_engine/common/include/transformer_engine/transformer_engine.h

@@ -383,6 +383,7 @@ enum class DType {
   kFloat8E4M3 = 7,
   kFloat8E5M2 = 8,
   kFloat8E8M0 = 9,
+  kInt8 = 10,
   kNumTypes
 };
transformer_engine/common/normalization/common.h

@@ -328,6 +328,7 @@ using byte = uint8_t;
 using int32 = int32_t;
 using fp32 = float;
 using fp16 = half;
+using int8 = int8_t;
 #ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;

@@ -358,6 +359,10 @@ struct TypeToDType<fp8e4m3> {
   static constexpr DType value = DType::kFloat8E4M3;
 };
+template <>
+struct TypeToDType<int8> {
+  static constexpr DType value = DType::kInt8;
+};
 template <>
 struct TypeToDType<fp8e5m2> {
   static constexpr DType value = DType::kFloat8E5M2;
 };
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

@@ -533,7 +533,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
           output.dtype, OutputType,
           TRANSFORMER_ENGINE_SWITCH_CONDITION(
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

@@ -257,7 +257,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
   for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-    const float other_amax = __shfl_down(amax, delta);
+    const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
     const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif

@@ -266,7 +266,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     amax = fmaxf(amax, other_amax);
   }
 #ifdef __HIP_PLATFORM_AMD__
-  amax = __shfl(amax, src_lane);
+  amax = __shfl(amax, src_lane, kThreadsPerWarp);
 #else
   amax = __shfl_sync(mask, amax, src_lane);
 #endif

@@ -354,7 +354,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
   for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-    const float other_amax = __shfl_down(amax, delta);
+    const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
     const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif

@@ -363,7 +363,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     amax = fmaxf(amax, other_amax);
   }
 #ifdef __HIP_PLATFORM_AMD__
-  amax = __shfl(amax, src_lane);
+  amax = __shfl(amax, src_lane, kThreadsPerWarp);
 #else
   amax = __shfl_sync(mask, amax, src_lane);
 #endif

@@ -479,7 +479,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
           output.dtype, OutputType,
           dim3 grid(num_blocks_x, num_blocks_y, 1);

Note on the shuffle fixes: HIP's __shfl and __shfl_down take an explicit width argument that defaults to the device wavefront size (64 lanes on AMD GPUs), so without it the amax reduction would span a whole wavefront. Passing kThreadsPerWarp keeps the reduction within the intended warp-sized lane group, matching the behavior of the CUDA *_sync variants.
transformer_engine/common/util/pybind_helper.h

@@ -26,7 +26,8 @@
       .value("kFloat16", transformer_engine::DType::kFloat16) \
       .value("kBFloat16", transformer_engine::DType::kBFloat16) \
       .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
-      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
+      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
+      .value("kInt8", transformer_engine::DType::kInt8); \
   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
       .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
transformer_engine/common/utils.cuh

@@ -990,6 +990,7 @@ using fp8e4m3 = te_hip_fp8_e4m3;
 using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
 using e8m0_t = uint8_t;
+using int8 = int8_t;
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;

@@ -1015,6 +1016,12 @@ struct Numeric_Traits<fp8e5m2> {
   static constexpr double maxNorm = 57344;
 };
+template <>
+struct Numeric_Traits<int8> {
+  static constexpr int maxUnbiasedExponent = 0;
+  static constexpr double maxNorm = 127;
+};
+
 template <typename T>
 struct Quantized_Limits {
   static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
transformer_engine/pytorch/constants.py

@@ -18,6 +18,7 @@ TE_DType = {
     torch.uint8: tex.DType.kByte,
     torch.float8_e4m3fn: tex.DType.kFloat8E4M3,
     torch.float8_e5m2: tex.DType.kFloat8E5M2,
+    torch.int8: tex.DType.kInt8,
     torch.int32: tex.DType.kInt32,
     torch.float32: tex.DType.kFloat32,
     torch.half: tex.DType.kFloat16,
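With the pybind registration above and this mapping, torch.int8 round-trips from Python to the new C++ enum member. A quick check, assuming a built TE install and that the compiled extension imports as transformer_engine_torch (the module this file aliases as tex; the name is an assumption):

import torch
import transformer_engine_torch as tex  # extension module name assumed
from transformer_engine.pytorch.constants import TE_DType

# torch.int8 now maps to the newly registered DType.kInt8 (enum value 10).
assert TE_DType[torch.int8] == tex.DType.kInt8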
transformer_engine/pytorch/fp8.py

@@ -60,6 +60,8 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
+    if IS_HIP_EXTENSION:
+        return True, ""
     if (
         get_device_compute_capability() >= (9, 0)
         and get_device_compute_capability() < (10, 0)
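The early return makes FP8 block scaling unconditionally available on HIP builds, while CUDA builds keep the compute-capability window. A sketch of how a caller might gate on the check (usage inferred from the Tuple[bool, str] signature shown; requires a TE install):

import pytest
from transformer_engine.pytorch.fp8 import check_fp8_block_scaling_support

supported, reason = check_fp8_block_scaling_support()

@pytest.mark.skipif(not supported, reason=reason or "fp8 block scaling unsupported")
def test_needs_block_scaling():
    ...  # body elided; runs only where block scaling is reported available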