OpenDAS / TransformerEngine

Commit 0a8072fa, authored Jun 09, 2025 by yuguo

[DCU] support casting master weights to int8

Parent: 2cbe1b70
Showing 9 changed files with 83 additions and 14 deletions (+83 −14):
- setup.py (+2 −2)
- tests/pytorch/distributed/test_cast_master_weights_to_fp8.py (+1 −1)
- transformer_engine/common/CMakeLists.txt (+7 −1)
- transformer_engine/common/common.h (+32 −0)
- transformer_engine/common/recipe/fp8_block_scaling.cu (+4 −4)
- transformer_engine/common/transpose/transpose.cu (+1 −1)
- transformer_engine/pytorch/cpp_extensions/gemm.py (+30 −2)
- transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp (+3 −2)
- transformer_engine/pytorch/tensor/utils.py (+3 −1)
setup.py

```diff
@@ -65,9 +65,9 @@ def setup_common_extension() -> CMakeExtension:
     cmake_flags = []
     if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
         cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
-    if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))):
+    if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
         cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
-    if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))):
+    if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"))):
         cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
     else:
```
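Both warning-suppression toggles now default to on ("0" → "1"), matching the pre-existing NVTE_BUILD_SUPPRESS_UNUSED_WARNING. A minimal sketch of the `bool(int(os.getenv(...)))` convention these flags share; the `env_flag` helper is mine, not the repo's:

```python
import os

def env_flag(name: str, default: str) -> bool:
    # "1" -> True, "0" -> False; a non-integer value raises ValueError rather
    # than silently reading as false, matching setup.py's bool(int(...)) idiom.
    return bool(int(os.getenv(name, default)))

# After this commit the two suppressions default to on for DCU builds;
# export NVTE_BUILD_SUPPRESS_SIGN_COMPARE=0 before building to opt out.
cmake_flags = []
if env_flag("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"):
    cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
print(cmake_flags)
```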
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

```diff
@@ -9,7 +9,7 @@ from pathlib import Path
 import pytest
 import torch

 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

+# NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block
 if torch.cuda.device_count() < 2:
     pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
```
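The added comment records how to reproduce the distributed test by hand with the int8-simulated FP8 path enabled. A hedged Python equivalent of that launch line, assuming `run_cast_master_weights_to_fp8.py` sits in the working directory as the comment implies:

```python
import os
import subprocess

# Illustrative only: mirrors the comment's manual invocation of the
# distributed cast test with int8 simulation on 4 GPUs.
env = dict(os.environ, NVTE_INT8_SIM_FP8="1")
subprocess.run(
    ["torchrun", "--nproc_per_node=4", "run_cast_master_weights_to_fp8.py",
     "--quantization", "fp8_block"],
    env=env, check=True,
)
```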
transformer_engine/common/CMakeLists.txt

```diff
@@ -125,6 +125,9 @@ if(USE_CUDA)
     transpose/quantize_transpose_square_blockwise.cu
     transpose/quantize_transpose_vector_blockwise.cu
     activation/gelu.cu
+    fused_attn/flash_attn.cu
+    fused_attn/context_parallel.cu
+    fused_attn/kv_cache.cu
     fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
     activation/relu.cu
@@ -165,6 +168,9 @@ else()
     cudnn_utils.cpp
     transformer_engine.cpp
     common.cu
+    fused_attn/flash_attn.cu
+    fused_attn/context_parallel.cu
+    fused_attn/kv_cache.cu
     multi_tensor/adam.cu
     multi_tensor/compute_scale.cu
     multi_tensor/l2norm.cu
@@ -222,6 +228,7 @@ else()
   set(header_include_dir
       ${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers
      ${CMAKE_CURRENT_SOURCE_DIR}/activation
+     ${CMAKE_CURRENT_SOURCE_DIR}/fused_attn
      ${CMAKE_CURRENT_SOURCE_DIR}/include
      ${CMAKE_CURRENT_SOURCE_DIR}/transpose
      ${CMAKE_CURRENT_SOURCE_DIR}/util
@@ -234,7 +241,6 @@ else()
   hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}
          HEADER_INCLUDE_DIR ${header_include_dir}
          IGNORES "*/amd_detail/*"
-         IGNORES "*/fused_attn/*"
          CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json")
   get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
```
transformer_engine/common/common.h

```diff
@@ -280,6 +280,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(half)
 TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
+TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
 #if CUDA_VERSION >= 12080
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
 #endif
@@ -455,6 +456,37 @@ struct TypeInfo {
       NVTE_ERROR("Invalid type."); \
   }
+
+#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(dtype, type, ...) \
+  switch (dtype) {                      \
+    using namespace transformer_engine; \
+    case DType::kFloat32: {             \
+      using type = float;               \
+      { __VA_ARGS__ }                   \
+    } break;                            \
+    case DType::kFloat16: {             \
+      using type = fp16;                \
+      { __VA_ARGS__ }                   \
+    } break;                            \
+    case DType::kBFloat16: {            \
+      using type = bf16;                \
+      { __VA_ARGS__ }                   \
+    } break;                            \
+    case DType::kFloat8E5M2: {          \
+      using type = fp8e5m2;             \
+      { __VA_ARGS__ }                   \
+    } break;                            \
+    case DType::kFloat8E4M3: {          \
+      using type = fp8e4m3;             \
+      { __VA_ARGS__ }                   \
+    } break;                            \
+    case DType::kInt8: {                \
+      using type = int8;                \
+      { __VA_ARGS__ }                   \
+    } break;                            \
+    default:                            \
+      NVTE_ERROR("Invalid type.");      \
+  }
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
   switch (dtype) {                      \
     using namespace transformer_engine; \
```
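The new macro extends the existing output-side type switch with a `DType::kInt8` case, so kernels such as the transpose below can be instantiated for int8. A rough Python analogue of the dispatch, illustrative only (the C++ macro expands to a compile-time `switch`, and the NumPy stand-ins for bf16/fp8 are mine):

```python
import numpy as np

# Illustrative analogue of TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8:
# resolve a runtime dtype tag to a concrete element type, now including kInt8.
_ELEMENT_TYPES = {
    "kFloat32": np.float32,
    "kFloat16": np.float16,
    "kBFloat16": np.uint16,   # stand-in: NumPy lacks bfloat16; same 2-byte layout
    "kFloat8E4M3": np.uint8,  # fp8 payloads handled as raw bytes here
    "kFloat8E5M2": np.uint8,
    "kInt8": np.int8,         # the case this commit adds
}

def dispatch_transpose(dtype_tag: str, data: np.ndarray) -> np.ndarray:
    elem = _ELEMENT_TYPES.get(dtype_tag)
    if elem is None:
        raise ValueError("Invalid type.")  # mirrors the macro's NVTE_ERROR default
    return data.view(elem).T.copy()
```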
transformer_engine/common/recipe/fp8_block_scaling.cu

```diff
@@ -53,7 +53,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-    float other_amax = __shfl_down(amax, delta);
+    float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
     float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
 #endif
```
```diff
@@ -124,14 +124,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-    bool other_skip_store = __shfl_down(skip_store, delta);
+    bool other_skip_store = __shfl_down(skip_store, delta, kThreadsPerWarp);
 #else
     bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
 #endif
     skip_store = skip_store && other_skip_store;
   }
 #ifdef __HIP_PLATFORM_AMD__
-  skip_store = __shfl(skip_store, 0);
+  skip_store = __shfl(skip_store, 0, kThreadsPerWarp);
 #else
   skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
 #endif
```
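In HIP, `__shfl_down` and `__shfl` take an optional width that defaults to the full wavefront (64 lanes on most AMD GPUs), so passing `kThreadsPerWarp` explicitly keeps the reduction groups identical to the CUDA `__shfl_down_sync` path. A NumPy simulation of the shuffle-down amax reduction these loops implement — semantics only, with `kThreadsPerWarp` assumed to be 32 and the combine step assumed to be a max:

```python
import numpy as np

def shfl_down(vals, delta, width):
    # Software model of the intrinsic: lane i reads lane i + delta within its
    # width-lane group; lanes whose source falls outside keep their own value.
    out = vals.copy()
    for i in range(len(vals)):
        j = i + delta
        if j < (i // width + 1) * width:
            out[i] = vals[j]
    return out

kThreadsPerWarp = 32  # assumed warp-group size; AMD wavefronts are 64 lanes,
                      # which is why the width must be passed explicitly on HIP
amax = np.random.rand(kThreadsPerWarp).astype(np.float32)

work = amax.copy()
delta = kThreadsPerWarp // 2
while delta > 0:
    work = np.maximum(work, shfl_down(work, delta, kThreadsPerWarp))
    delta //= 2

assert work[0] == amax.max()  # lane 0 now holds the warp-wide amax
```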
```diff
@@ -217,7 +217,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       inp.dtype(), inp_dtype,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
           out_dtype, fp8_type,
           TRANSFORMER_ENGINE_SWITCH_CONDITION(
               w % kTileDim == 0, kWidthAligned,
```
transformer_engine/common/transpose/transpose.cu

```diff
@@ -211,7 +211,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
     NVTE_CHECK(noop.data.dptr != nullptr);
   }

-  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
       input.data.dtype, Type,
       constexpr const char *type_name = TypeInfo<Type>::name;
       constexpr size_t type_size = sizeof(Type);
```
transformer_engine/pytorch/cpp_extensions/gemm.py

```diff
@@ -60,7 +60,6 @@ def general_gemm(
         assert not gelu, "GELU not supported with int8 simulation"
         assert gelu_in is None, "GELU input not supported with int8 simulation"
         assert bias is None, "Bias not supported with int8 simulation"
-        assert not accumulate, "Accumulation not supported with int8 simulation"
         assert ub is None, "User buffer not supported with int8 simulation"
         assert ub_type is None, "User buffer type not supported with int8 simulation"
         assert extra_output is None, "Extra output not supported with int8 simulation"
```
```diff
@@ -80,6 +79,11 @@ def general_gemm(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
+        if accumulate:
+            assert out is not None
+            y = y + out
+        else:
+            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     elif layout == "NN":
@@ -96,6 +100,11 @@ def general_gemm(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
+        if accumulate:
+            assert out is not None
+            y = y + out
+        else:
+            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     elif layout == "NT":
@@ -112,6 +121,11 @@ def general_gemm(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
             output_dtype=out_dtype)
+        if accumulate:
+            assert out is not None
+            y = y + out
+        else:
+            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     else:
```
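Each layout branch of the int8-simulation path now honors `accumulate` by adding the caller's existing output after the simulated GEMM, instead of asserting it away. A condensed sketch of the shared pattern; `int8_block_gemm_sim` is a hypothetical stand-in for the repo's (elided) simulation call:

```python
import torch

def int8_block_gemm_sim(qa, qb, scales_a, scales_b, output_dtype):
    # Hypothetical stand-in: dequantize the int8 operands with their per-block
    # scales (pre-expanded to element shape here for simplicity), then run the
    # reference matmul in float32, as the "TN" layout would.
    a = qa.float() / scales_a
    b = qb.float() / scales_b
    return (a @ b.t()).to(output_dtype)

def simulated_general_gemm(qx, qw, sx, sw, out_dtype, accumulate=False, out=None):
    y = int8_block_gemm_sim(qx, qw, sx, sw, out_dtype)
    if accumulate:
        assert out is not None
        y = y + out  # fold the caller's existing output in, as each new branch does
    else:
        assert out is None, "Output tensor should be None when accumulate is False."
    return y, None, None, None
```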
```diff
@@ -203,7 +217,6 @@ def general_grouped_gemm(
         assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
         assert bias is None, "Bias not supported with int8 simulation groupgemm."
-        assert not accumulate, "Accumulation not supported with int8 simulation groupgemm."
         if layout == "TN":
             qx_data = [
```
```diff
@@ -219,6 +232,11 @@ def general_grouped_gemm(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
+        if accumulate:
+            assert out is not None
+            y = y + out
+        else:
+            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     elif layout == "NN":
@@ -235,6 +253,11 @@ def general_grouped_gemm(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
+        if accumulate:
+            assert out is not None
+            y = y + out
+        else:
+            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     elif layout == "NT":
@@ -251,6 +274,11 @@ def general_grouped_gemm(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
             output_dtype=out_dtype)
+        if accumulate:
+            assert out is not None
+            y = y + out
+        else:
+            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     else:
```
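The grouped path keeps its restriction that all `m_splits` be equal, which is what lets the simulation treat a grouped GEMM as a single batched GEMM. A sketch of that equivalence (function name hypothetical):

```python
import torch

def grouped_gemm_as_batched(xs, ws):
    # Valid only when every group has the same token count: stack the equally
    # sized groups and run one batched matmul instead of per-group GEMMs.
    m_splits = [x.shape[0] for x in xs]
    assert len(set(m_splits)) == 1, "groups must be padded to the same length"
    x = torch.stack(xs)                      # [G, m, k]
    w = torch.stack(ws)                      # [G, n, k]
    return torch.bmm(x, w.transpose(1, 2))   # [G, m, n]
```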
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp

```diff
@@ -36,8 +36,9 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
               "input must be a float or bfloat16 tensor");
   TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
   TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
-                  out_dtype == transformer_engine::DType::kFloat8E5M2,
-              "out_dtype must be kFloat8E4M3 or kFloat8E5M2");
+                  out_dtype == transformer_engine::DType::kFloat8E5M2 ||
+                  out_dtype == transformer_engine::DType::kInt8,
+              "out_dtype must be kFloat8E4M3 or kFloat8E5M2 or kInt8");

   const TensorWrapper inp_cu = makeTransformerEngineTensor(inp);
   TensorWrapper out_cu = makeTransformerEngineTensor(out);
```
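Note that the binding still requires a `uint8` output buffer even for `kInt8`: the quantized payload is stored as raw bytes, and only the logical `out_dtype` gains a third option. An illustrative Python-side restatement of the updated checks (the helper is mine, not a repo API):

```python
import torch

def check_partial_cast_args(inp: torch.Tensor, out: torch.Tensor, out_dtype: str):
    # Mirrors the updated TORCH_CHECKs: int8 joins the two fp8 formats as a
    # valid target, while the output stays a raw uint8 byte tensor in every case.
    assert inp.dtype in (torch.float32, torch.bfloat16), \
        "input must be a float or bfloat16 tensor"
    assert out.dtype == torch.uint8, "output must be a uint8 tensor"
    assert out_dtype in ("kFloat8E4M3", "kFloat8E5M2", "kInt8"), \
        "out_dtype must be kFloat8E4M3 or kFloat8E5M2 or kInt8"
```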
transformer_engine/pytorch/tensor/utils.py

```diff
@@ -414,6 +414,8 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
         max_fp8 = 448.0
     elif fp8_dtype == tex.DType.kFloat8E5M2:
         max_fp8 = 57344.0
+    elif fp8_dtype == tex.DType.kInt8:
+        max_fp8 = 127.0
     else:
         raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")

     multi_tensor_applier(
```
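For `kInt8` the blockwise recipe reuses the fp8 machinery with 127.0 as the largest representable magnitude, alongside 448.0 (E4M3) and 57344.0 (E5M2). A minimal NumPy sketch of the per-block scale this implies, assuming round-to-nearest:

```python
import numpy as np

def int8_blockwise_quantize(block: np.ndarray):
    # Per-block scale maps the block's amax onto 127, the largest int8
    # magnitude, mirroring max_fp8 = 448.0 (E4M3) and 57344.0 (E5M2).
    amax = np.abs(block).max()
    scale = 127.0 / amax if amax > 0 else 1.0
    q = np.clip(np.rint(block * scale), -127, 127).astype(np.int8)
    return q, scale  # dequantize as q.astype(np.float32) / scale

block = np.random.randn(128, 128).astype(np.float32)
q, scale = int8_blockwise_quantize(block)
recon = q.astype(np.float32) / scale
print(np.abs(block - recon).max())  # rounding error bounded by 0.5 / scale
```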
```diff
@@ -435,7 +437,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
         # We cannot create columnwise data here because users (like megatron) may want to overlap
         # the all-gather of model weights and forward process, so the model weight is not updated
         # at this moment.
-        model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
+        model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)  # May cause core dump in iter 2
         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
```