Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
2b05e121
Commit
2b05e121
authored
Jun 17, 2025
by
yuguo
Browse files
Merge commit '
a69692ac
' of...
Merge commit '
a69692ac
' of
https://github.com/NVIDIA/TransformerEngine
parents
0fd441c2
a69692ac
Changes
245
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
295 additions
and
81 deletions
+295
-81
transformer_engine/pytorch/csrc/extensions/attention.cpp
transformer_engine/pytorch/csrc/extensions/attention.cpp
+11
-11
transformer_engine/pytorch/csrc/extensions/cast.cpp
transformer_engine/pytorch/csrc/extensions/cast.cpp
+4
-1
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
...rmer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
+8
-0
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp
...ytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/gemm.cpp
transformer_engine/pytorch/csrc/extensions/gemm.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/misc.cpp
transformer_engine/pytorch/csrc/extensions/misc.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
...rmer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
...ne/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
...er_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
...mer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp
...ormer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/normalization.cpp
transformer_engine/pytorch/csrc/extensions/normalization.cpp
+7
-1
transformer_engine/pytorch/csrc/extensions/padding.cpp
transformer_engine/pytorch/csrc/extensions/padding.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/permutation.cpp
transformer_engine/pytorch/csrc/extensions/permutation.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/pybind.cpp
transformer_engine/pytorch/csrc/extensions/pybind.cpp
+6
-3
transformer_engine/pytorch/csrc/extensions/recipe.cpp
transformer_engine/pytorch/csrc/extensions/recipe.cpp
+25
-20
transformer_engine/pytorch/csrc/extensions/softmax.cpp
transformer_engine/pytorch/csrc/extensions/softmax.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/transpose.cpp
transformer_engine/pytorch/csrc/extensions/transpose.cpp
+1
-1
transformer_engine/pytorch/csrc/quantizer.cpp
transformer_engine/pytorch/csrc/quantizer.cpp
+50
-19
transformer_engine/pytorch/distributed.py
transformer_engine/pytorch/distributed.py
+172
-14
No files found.
transformer_engine/pytorch/csrc/extensions/attention.cpp
View file @
2b05e121
...
...
@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "extensions.h"
#include "pybind.h"
namespace
{
...
...
@@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
NVTE_CHECK
(
fcd_size
%
block_size
==
0
,
"input size not aligned to block size"
);
size_t
element_size
=
transformer_engine
::
pytorch
::
typeTo
Size
(
self
.
dtype
());
size_t
element_size
_bits
=
transformer_engine
::
pytorch
::
typeTo
NumBits
(
self
.
dtype
());
int32_t
start_row
=
start_index
.
data_ptr
<
int32_t
>
()[
0
];
void
*
base_ptr
=
static_cast
<
char
*>
(
self
.
get_rowwise_data
().
data_ptr
)
+
static_cast
<
size_t
>
(
start_row
)
*
fcd_size
*
element_size
;
static_cast
<
size_t
>
(
start_row
)
*
fcd_size
*
element_size
_bits
/
8
;
size_t
num_rows_to_zero
=
max_tokens
-
start_row
;
size_t
total_bytes
=
num_rows_to_zero
*
fcd_size
*
element_size
;
size_t
total_bytes
=
num_rows_to_zero
*
fcd_size
*
element_size
_bits
/
8
;
NVTE_SCOPED_GIL_RELEASE
(
{
nvte_memset
(
base_ptr
,
0
,
total_bytes
,
at
::
cuda
::
getCurrentCUDAStream
());
});
...
...
@@ -57,17 +57,17 @@ namespace transformer_engine::pytorch {
// get the fused attention backend
NVTE_Fused_Attn_Backend
get_fused_attn_backend
(
const
DType
q_dtype
,
const
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
float
p_dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
bool
is_training
,
const
DType
q_dtype
,
const
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
float
p_dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
#ifdef __HIP_PLATFORM_AMD__
return
NVTE_Fused_Attn_Backend
::
NVTE_No_Backend
;
#else
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
static_cast
<
NVTEDType
>
(
q_dtype
),
static_cast
<
NVTEDType
>
(
kv_dtype
),
qkv_layout
,
bias_type
,
attn_mask_type
,
p_dropout
,
num_attn_heads
,
num_gqa_groups
,
max_seqlen_q
,
max_seqlen_kv
,
head_dim_qk
,
head_dim_v
,
window_size_left
,
window_size_right
);
is_training
,
static_cast
<
NVTEDType
>
(
q_dtype
),
static_cast
<
NVTEDType
>
(
kv_dtype
),
qkv_layout
,
bias_type
,
attn_mask_type
,
p_dropout
,
num_attn_heads
,
num_gqa_groups
,
max_seqlen_q
,
max_seqlen_kv
,
head_dim_qk
,
head_dim_v
,
window_size_left
,
window_size_right
);
return
fused_attention_backend
;
#endif
}
...
...
transformer_engine/pytorch/csrc/extensions/cast.cpp
View file @
2b05e121
...
...
@@ -6,8 +6,8 @@
#include "transformer_engine/cast.h"
#include "../extensions.h"
#include "common.h"
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
...
...
@@ -81,6 +81,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
auto
my_quantizer_bw
=
static_cast
<
Float8BlockQuantizer
*>
(
my_quantizer
.
get
());
quant_config
.
set_force_pow_2_scales
(
my_quantizer_bw
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_bw
->
amax_epsilon
);
if
(
my_quantizer_bw
->
all_gather_usage
)
{
quant_config
.
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
::
COMPACT
);
}
}
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
te_input
.
data
(),
te_output
.
data
(),
quant_config
,
...
...
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
View file @
2b05e121
...
...
@@ -216,6 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
return
torch
::
from_blob
(
ubuf_ptr
,
*
shape
,
at
::
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
}
// Returns the communication stream held in `_stream_comm`, wrapped as an
// `at::Stream` associated with the CUDA device that is current at call time.
at::Stream CommOverlap::get_communication_stream() {
  const auto device_index = at::cuda::current_device();
  return at::cuda::getStreamFromExternal(_stream_comm, device_index);
}
/***************************************************************************************************
* CommOverlapP2P
**************************************************************************************************/
...
...
@@ -300,3 +304,7 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
const
auto
dtype
=
transformer_engine
::
pytorch
::
GetATenDType
(
_ubuf
.
dtype
());
return
torch
::
from_blob
(
ubuf_ptr
,
*
shape
,
at
::
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
}
// Returns the stream held in `_stream_recv`, wrapped as an `at::Stream`
// associated with the CUDA device that is current at call time.
at::Stream CommOverlapP2P::get_communication_stream() {
  const auto device_index = at::cuda::current_device();
  return at::cuda::getStreamFromExternal(_stream_recv, device_index);
}
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @
2b05e121
...
...
@@ -10,10 +10,10 @@
#include <string>
#include "../common.h"
#include "../extensions.h"
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
...
...
transformer_engine/pytorch/csrc/extensions/misc.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/normalization.cpp
View file @
2b05e121
...
...
@@ -4,8 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common/util/system.h"
#include "extensions.h"
#include "pybind.h"
namespace
transformer_engine
::
pytorch
{
...
...
@@ -170,6 +170,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
auto
my_quantizer_bw
=
static_cast
<
Float8BlockQuantizer
*>
(
my_quantizer
.
get
());
quant_config
.
set_force_pow_2_scales
(
my_quantizer_bw
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_bw
->
amax_epsilon
);
if
(
my_quantizer_bw
->
all_gather_usage
)
{
quant_config
.
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
::
COMPACT
);
}
}
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
quant_config
,
...
...
@@ -328,6 +331,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
auto
my_quantizer_bw
=
static_cast
<
Float8BlockQuantizer
*>
(
my_quantizer
.
get
());
quant_config
.
set_force_pow_2_scales
(
my_quantizer_bw
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_bw
->
amax_epsilon
);
if
(
my_quantizer_bw
->
all_gather_usage
)
{
quant_config
.
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
::
COMPACT
);
}
}
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
quant_config
,
...
...
transformer_engine/pytorch/csrc/extensions/padding.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../
extensions.h"
#include "pybind.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/permutation.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @
2b05e121
...
...
@@ -261,7 +261,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Get cublasLt version"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"get_cudnn_version"
,
&
transformer_engine
::
pytorch
::
get_cudnn_version
,
"Get cuDNN version"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
attr
(
"_num_cublas_streams"
)
=
py
::
int_
(
transformer_engine
::
num_streams
);
m
.
def
(
"get_num_cublas_streams"
,
&
nvte_get_num_compute_streams
,
"Get number of compute streams"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#ifdef USE_ROCM
m
.
attr
(
"_num_cublas_batchgemm_streams"
)
=
py
::
int_
(
transformer_engine
::
num_batchgemm_streams
);
#endif
...
...
@@ -390,7 +391,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"copy_into_buffer"
,
&
CommOverlap
::
copy_into_buffer
,
py
::
arg
(
"input"
),
py
::
arg
(
"local_chunk"
)
=
false
)
.
def
(
"get_buffer"
,
&
CommOverlap
::
get_buffer
,
py
::
arg
(
"local_chunk"
)
=
false
,
py
::
arg
(
"shape"
)
=
std
::
nullopt
);
py
::
arg
(
"shape"
)
=
std
::
nullopt
)
.
def
(
"get_communication_stream"
,
&
CommOverlap
::
get_communication_stream
);
py
::
class_
<
CommOverlapP2P
,
std
::
shared_ptr
<
CommOverlapP2P
>
,
transformer_engine
::
CommOverlapP2PBase
,
transformer_engine
::
CommOverlapCore
>
(
...
...
@@ -407,5 +409,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"copy_into_buffer"
,
&
CommOverlapP2P
::
copy_into_buffer
,
py
::
arg
(
"input"
),
py
::
arg
(
"local_chunk"
)
=
false
)
.
def
(
"get_buffer"
,
&
CommOverlapP2P
::
get_buffer
,
py
::
arg
(
"local_chunk"
)
=
false
,
py
::
arg
(
"shape"
)
=
std
::
nullopt
);
py
::
arg
(
"shape"
)
=
std
::
nullopt
)
.
def
(
"get_communication_stream"
,
&
CommOverlapP2P
::
get_communication_stream
);
}
transformer_engine/pytorch/csrc/extensions/recipe.cpp
View file @
2b05e121
...
...
@@ -9,8 +9,8 @@
#include <string>
#include "
common/comm
on.h"
#include "
extensions
.h"
#include "
../extensi
on
s
.h"
#include "
transformer_engine/transformer_engine
.h"
namespace
transformer_engine
::
pytorch
{
...
...
@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
const
std
::
string
&
amax_compute_algo
,
DType
fp8_dtype
,
float
margin
)
{
size_t
num_tensors
=
amax_histories
.
size
();
std
::
vector
<
Tensor
>
t_amax_histories
(
num_tensors
)
;
std
::
vector
<
Tensor
>
t_scales
(
num_tensors
)
;
std
::
vector
<
NVTETensor
>
te_amax_histories
(
num_tensors
);
std
::
vector
<
NVTETensor
>
te_scales
(
num_tensors
);
std
::
vector
<
NVTE
Tensor
>
t
e
_amax_histories
;
std
::
vector
<
NVTE
Tensor
>
t
e
_scales
;
te_amax_histories
.
reserve
(
num_tensors
);
te_scales
.
reserve
(
num_tensors
);
for
(
size_t
i
=
0
;
i
<
num_tensors
;
i
++
)
{
t_amax_histories
[
i
].
data
.
dptr
=
amax_histories
[
i
].
data_ptr
();
auto
amax_sizes
=
amax_histories
[
i
].
sizes
().
vec
();
std
::
vector
<
size_t
>
amax_shape
{
amax_sizes
.
begin
(),
amax_sizes
.
end
()};
t_amax_histories
[
i
].
data
.
shape
=
amax_shape
;
t_amax_histories
[
i
].
data
.
dtype
=
DType
::
kFloat32
;
t_scales
[
i
].
data
.
dptr
=
scales
[
i
].
data_ptr
();
auto
scale_sizes
=
scales
[
i
].
sizes
().
vec
();
std
::
vector
<
size_t
>
scale_shape
{
scale_sizes
.
begin
(),
scale_sizes
.
end
()};
t_scales
[
i
].
data
.
shape
=
scale_shape
;
t_scales
[
i
].
data
.
dtype
=
DType
::
kFloat32
;
te_amax_histories
[
i
]
=
reinterpret_cast
<
NVTETensor
>
(
&
t_amax_histories
[
i
]);
te_scales
[
i
]
=
reinterpret_cast
<
NVTETensor
>
(
&
t_scales
[
i
]);
te_amax_histories
.
push_back
(
nvte_create_tensor
(
NVTE_DELAYED_TENSOR_SCALING
));
NVTETensor
&
amax_history
=
te_amax_histories
.
back
();
NVTEShape
amax_shape
=
convertTorchShape
(
amax_histories
[
i
].
sizes
());
NVTEBasicTensor
amax_history_data
=
{
amax_histories
[
i
].
data_ptr
(),
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
),
amax_shape
};
nvte_set_tensor_param
(
&
amax_history
,
kNVTERowwiseData
,
&
amax_history_data
);
te_scales
.
push_back
(
nvte_create_tensor
(
NVTE_DELAYED_TENSOR_SCALING
));
NVTETensor
&
scale
=
te_scales
.
back
();
NVTEShape
scale_shape
=
convertTorchShape
(
scales
[
i
].
sizes
());
NVTEBasicTensor
scale_data
=
{
scales
[
i
].
data_ptr
(),
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
),
scale_shape
};
nvte_set_tensor_param
(
&
scale
,
kNVTERowwiseData
,
&
scale_data
);
}
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction
(
makeTransformerEngineTensor
(
amax_reduction_buffer
).
data
(),
te_amax_histories
,
te_scales
,
amax_compute_algo
.
c_str
(),
static_cast
<
NVTEDType
>
(
fp8_dtype
),
margin
,
at
::
cuda
::
getCurrentCUDAStream
());
for
(
auto
&
t
:
te_amax_histories
)
{
nvte_destroy_tensor
(
t
);
}
for
(
auto
&
t
:
te_scales
)
{
nvte_destroy_tensor
(
t
);
}
}
}
// namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/softmax.cpp
View file @
2b05e121
...
...
@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "
../
extensions.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/transpose.cpp
View file @
2b05e121
...
...
@@ -6,7 +6,7 @@
#include <optional>
#include "extensions.h"
#include "
../
extensions.h"
#include "pybind.h"
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/quantizer.cpp
View file @
2b05e121
...
...
@@ -261,6 +261,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
this
->
amax_epsilon
=
quantizer
.
attr
(
"amax_epsilon"
).
cast
<
float
>
();
NVTE_CHECK
(
this
->
block_scaling_dim
==
1
||
this
->
block_scaling_dim
==
2
,
"Unsupported block scaling dim."
);
this
->
all_gather_usage
=
quantizer
.
attr
(
"all_gather_usage"
).
cast
<
bool
>
();
}
void
Float8BlockQuantizer
::
set_quantization_params
(
TensorWrapper
*
tensor
)
const
{
...
...
@@ -299,6 +300,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
size_t
m_dim
=
numel
/
k_dim
;
constexpr
size_t
kBlockLen
=
128
;
Float8BlockScaleTensorFormat
data_format
=
(
all_gather_usage
?
Float8BlockScaleTensorFormat
::
COMPACT
:
Float8BlockScaleTensorFormat
::
GEMM_READY
);
if
(
rowwise_usage
)
{
if
(
rowwise_data
.
has_value
())
{
data_rowwise
=
std
::
move
(
*
rowwise_data
);
...
...
@@ -308,14 +313,24 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
size_t
sinv0
=
0
;
size_t
sinv1
=
0
;
if
(
block_scaling_dim
==
2
)
{
// 2D scaling is always GEMM_READY for now
NVTE_CHECK
(
data_format
==
Float8BlockScaleTensorFormat
::
GEMM_READY
,
"2D scaling is always GEMM_READY for now."
);
sinv0
=
(
m_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
roundup
((
k_dim
+
kBlockLen
-
1
)
/
kBlockLen
,
4
);
}
else
if
(
block_scaling_dim
==
1
)
{
// 1D scaling can be GEMM_READY or COMPACT
bool
rowwise_compact
=
data_format
==
Float8BlockScaleTensorFormat
::
COMPACT
;
// the default rowwise scaling factor shape already has the scaling factor transposed, so it's GEMM_READY
sinv0
=
(
k_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
roundup
(
m_dim
,
4
);
sinv1
=
rowwise_compact
?
m_dim
:
roundup
(
m_dim
,
4
);
// if the rowwise format is compact, the scaling factor is not transposed
if
(
rowwise_compact
)
{
std
::
swap
(
sinv0
,
sinv1
);
}
}
else
{
NVTE_
CHECK
(
false
,
"Unsupported block_scaling_dim in create_tensor rowwise."
NVTE_
ERROR
(
"Unsupported block_scaling_dim in create_tensor rowwise.
"
"Expected 1 or 2. Got "
,
block_scaling_dim
);
}
...
...
@@ -332,6 +347,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
NVTE_CHECK
(
torch_shape
.
size
()
==
shape
.
size
(),
"Shape expected to match torch shape. Shape "
,
columnwise_shape
,
" torch shape: "
,
torch_columnwise_shape
);
if
(
torch_shape
.
size
()
>
0
)
{
if
(
!
all_gather_usage
)
{
torch_columnwise_shape
.
reserve
(
torch_shape
.
size
());
columnwise_shape
.
reserve
(
shape
.
size
());
torch_columnwise_shape
.
push_back
(
torch_shape
[
torch_shape
.
size
()
-
1
]);
...
...
@@ -340,18 +356,32 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
torch_columnwise_shape
.
push_back
(
torch_shape
[
i
]);
columnwise_shape
.
push_back
(
shape
[
i
]);
}
}
else
{
// assert we are doing 1D scaling
NVTE_CHECK
(
block_scaling_dim
==
1
,
"Compact columnwise format is not supported for 128x128 2D block scaling."
);
torch_columnwise_shape
=
torch_shape
;
columnwise_shape
=
shape
;
}
}
size_t
sinv0
=
0
;
size_t
sinv1
=
0
;
if
(
block_scaling_dim
==
2
)
{
// 2D scaling is always GEMM_READY for now
NVTE_CHECK
(
data_format
==
Float8BlockScaleTensorFormat
::
GEMM_READY
,
"2D scaling is always GEMM_READY for now."
);
sinv0
=
(
k_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
roundup
((
m_dim
+
kBlockLen
-
1
)
/
kBlockLen
,
4
);
}
else
if
(
block_scaling_dim
==
1
)
{
bool
columnwise_compact
=
data_format
==
Float8BlockScaleTensorFormat
::
COMPACT
;
sinv0
=
(
m_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
roundup
(
k_dim
,
4
);
sinv1
=
columnwise_compact
?
k_dim
:
roundup
(
k_dim
,
4
);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
}
else
{
NVTE_
CHECK
(
false
,
"Unsupported block_scaling_dim in create_tensor columnwise."
NVTE_
ERROR
(
"Unsupported block_scaling_dim in create_tensor columnwise.
"
"Expected 1 or 2. Got "
,
block_scaling_dim
);
}
...
...
@@ -373,7 +403,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"rowwise_data"
_a
=
data_rowwise
,
"columnwise_data"
_a
=
data_colwise
,
"rowwise_scale_inv"
_a
=
scale_inv_rowwise
,
"columnwise_scale_inv"
_a
=
scale_inv_colwise
,
"fp8_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
,
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
));
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
)
,
"data_format"
_a
=
data_format
);
}
else
{
py
::
handle
Float8BlockwiseQTensorClass
(
reinterpret_cast
<
PyObject
*>
(
Float8BlockwiseQTensorPythonClass
));
...
...
@@ -381,7 +411,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"shape"
_a
=
torch_shape
,
"dtype"
_a
=
GetATenDType
(
dtype
),
"rowwise_data"
_a
=
data_rowwise
,
"columnwise_data"
_a
=
data_colwise
,
"rowwise_scale_inv"
_a
=
scale_inv_rowwise
,
"columnwise_scale_inv"
_a
=
scale_inv_colwise
,
"fp8_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
,
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
));
"quantizer"
_a
=
this
->
quantizer
,
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
),
"data_format"
_a
=
data_format
);
}
return
{
std
::
move
(
tensor
),
std
::
move
(
ret
)};
...
...
transformer_engine/pytorch/distributed.py
View file @
2b05e121
...
...
@@ -8,6 +8,7 @@ from __future__ import annotations
from
collections.abc
import
Iterable
from
contextlib
import
contextmanager
,
AbstractContextManager
,
ContextDecorator
from
functools
import
lru_cache
from
dataclasses
import
dataclass
import
math
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
warnings
...
...
@@ -19,6 +20,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from
torch.distributed.fsdp._common_utils
import
_get_module_fsdp_state
from
torch.distributed.fsdp._traversal_utils
import
_get_fsdp_states_with_modules
try
:
import
torch.distributed._symmetric_memory
as
symm_mem
HAS_TORCH_SYMMETRIC
=
True
except
ImportError
:
HAS_TORCH_SYMMETRIC
=
False
import
transformer_engine_torch
as
tex
from
.
import
torch_version
from
.utils
import
(
is_non_tn_fp8_gemm_supported
,
...
...
@@ -34,14 +44,8 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from
.tensor._internal.float8_tensor_base
import
Float8TensorBase
from
.tensor._internal.mxfp8_tensor_base
import
MXFP8TensorBase
from
.tensor._internal.float8_blockwise_tensor_base
import
Float8BlockwiseQTensorBase
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
,
DebugQuantizer
try
:
import
torch.distributed._symmetric_memory
as
symm_mem
HAS_TORCH_SYMMETRIC
=
True
except
ImportError
:
HAS_TORCH_SYMMETRIC
=
False
__all__
=
[
"checkpoint"
,
"CudaRNGStatesTracker"
]
...
...
@@ -943,7 +947,7 @@ def _all_gather_fp8(
out
=
quantizer
.
make_empty
(
out_shape
,
dtype
=
dtype
,
device
=
device
)
elif
isinstance
(
inp
,
Float8Tensor
):
out
=
inp
.
make_like
(
inp
,
shape
=
out_shape
)
out
.
_data
=
torch
.
empty
_like
(
out
.
_data
=
torch
.
empty
(
out_shape
,
dtype
=
torch
.
uint8
,
device
=
inp
.
device
,
...
...
@@ -977,6 +981,67 @@ def _all_gather_fp8(
return
out
,
handle
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
    """Set the ``all_gather_usage`` flag of a block-scaling quantizer.

    A ``DebugQuantizer`` is unwrapped to its ``parent_quantizer`` first. The
    flag is only written when the resulting object is a
    ``Float8BlockQuantizer``; any other quantizer is left untouched.

    Args:
        quantizer: Quantizer (possibly wrapped in a ``DebugQuantizer``).
        compact: Value assigned to ``all_gather_usage``.
    """
    target = quantizer.parent_quantizer if isinstance(quantizer, DebugQuantizer) else quantizer
    if not isinstance(target, Float8BlockQuantizer):
        return
    target.all_gather_usage = compact
def _post_process_fp8_blockwise_gather(
    out: Float8BlockwiseQTensorBase,
    quantizer: Float8BlockQuantizer,
    handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorBase:
    """Finish an FP8 blockwise all-gather by converting ``out`` to GEMM_READY format.

    Args:
        out: Gathered tensor; modified in place and returned.
        quantizer: Quantizer whose ``rowwise_usage`` / ``columnwise_usage``
            flags decide which transposes are applied. ``None`` is tolerated
            and disables both transposes.
        handle: Optional pending communication handle; waited on first.

    Returns:
        ``out``. If it already reports GEMM-ready format it is returned
        unchanged; otherwise its ``_data_format`` is set to ``GEMM_READY``.
    """
    if handle is not None:
        handle.wait()

    # Nothing to convert when the tensor is already laid out for GEMM.
    if out._is_gemm_ready_format():
        return out

    has_quantizer = quantizer is not None
    transpose_columnwise_data = (
        has_quantizer and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
    )
    transpose_rowwise_scale = (
        has_quantizer and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
    )

    # cuBLAS requires a transposed scale-inv tensor. Suppose the original input
    # is 256x1024. Columnwise compact format means 128x1 quantization of it, so
    # the quantized tensor is 256x1024 and the scale inv is 2x1024.
    # In GEMM_READY format this is equivalent to 1x128 quantization of the
    # transposed 1024x256 tensor, whose scale inv is 1024x2, and cuBLAS
    # requires 2x1024.
    # Therefore the columnwise scale inv does not need to be transposed, only
    # the columnwise data.
    if transpose_columnwise_data:
        out._transpose_columnwise_data()
    if transpose_rowwise_scale:
        transposed_scale = out._rowwise_scale_inv.transpose(-2, -1)
        out._rowwise_scale_inv = transposed_scale.contiguous()

    out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
    return out
@dataclass
class _FP8BlockwiseAllGatherAsyncHandle:
    """Handle for an asynchronous FP8 blockwise all-gather.

    Pairs the pending communication handle with the gathered tensor and its
    quantizer so that post-processing runs when the caller waits.
    """

    # Gathered output tensor that is post-processed in place by wait().
    tensor: Float8BlockwiseQTensorBase
    # Quantizer passed to the post-processing step.
    quantizer: Float8BlockQuantizer
    # Pending communication handle.
    async_handle: torch.distributed.Work
    # Set to True after the first wait() so later calls are no-ops.
    _synchronized: bool = False

    def wait(self) -> None:
        """Wait for the communication to finish, then post-process the tensor.

        Repeated calls after the first one do nothing.
        """
        if not self._synchronized:
            self.async_handle.wait()
            _post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
            self._synchronized = True
def
_all_gather_fp8_blockwise
(
inp
:
torch
.
Tensor
,
process_group
:
dist_group_type
,
...
...
@@ -990,8 +1055,9 @@ def _all_gather_fp8_blockwise(
Returns: quantizer(gather(inp))
NOTE: The implementation is not sophisticated enough to honor async_op=True.
In some cases it falls back to synchronous gather and invokes the quantizer.
NOTE: The implementation is only going to honor async_op=True for FP8 gather case.
In the case where tensor shape is not divisible by 128, the implementation will fall back
to synchronous gather and invoke the quantizer.
"""
# Input tensor attributes
...
...
@@ -1027,7 +1093,11 @@ def _all_gather_fp8_blockwise(
out_shape
[
0
]
*=
world_size
# Doing BF16 gather for now as baseline because it's simpler
if
not
isinstance
(
inp
,
Float8BlockwiseQTensorBase
)
and
quantizer
is
not
None
:
if
(
not
isinstance
(
inp
,
Float8BlockwiseQTensorBase
)
and
quantizer
is
not
None
and
not
quantizer
.
is_quantizable
(
inp
)
):
out
=
torch
.
empty
(
out_shape
,
dtype
=
dtype
,
...
...
@@ -1035,14 +1105,93 @@ def _all_gather_fp8_blockwise(
memory_format
=
torch
.
contiguous_format
,
)
torch
.
distributed
.
all_gather_into_tensor
(
out
,
inp
,
group
=
process_group
,
async_op
=
False
)
orig_all_gather_usage
=
quantizer
.
all_gather_usage
quantizer
.
all_gather_usage
=
False
out
=
quantizer
(
out
)
quantizer
.
all_gather_usage
=
orig_all_gather_usage
return
out
,
None
# Implementation of fp8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMS.
# * Gathering non GEMM swizzled scales.
# * Refer to scaffold code when implementing at:
# https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
raise
NotImplementedError
(
"fp8 blockwise allgather not yet implemented"
)
# Cast input tensor to Float8BlockwiseQTensor with required data
# Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage
=
quantizer
.
all_gather_usage
quantizer
.
all_gather_usage
=
True
if
not
isinstance
(
inp
,
Float8BlockwiseQTensorBase
):
inp
=
quantizer
(
inp
)
elif
(
quantizer
.
rowwise_usage
and
inp
.
_rowwise_data
is
None
)
or
(
quantizer
.
columnwise_usage
and
inp
.
_columnwise_data
is
None
):
warnings
.
warn
(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to Float8BlockwiseQTensor."
)
inp
=
quantizer
(
inp
.
dequantize
())
quantizer
.
all_gather_usage
=
orig_all_gather_usage
# Begin to do network communication, need to make sure compact format
if
inp
.
_data_format
!=
tex
.
Float8BlockScaleTensorFormat
.
COMPACT
:
raise
RuntimeError
(
"All-gather with FP8 block-wise quantized tensor requires compact data format, "
f
"but found data_format=
{
inp
.
_data_format
}
"
)
# Construct Float8BlockwiseQTensor output tensor
out
=
quantizer
.
make_empty
(
out_shape
,
dtype
=
dtype
,
device
=
device
)
# Coalesce NCCL collectives
with
torch
.
distributed
.
_coalescing_manager
(
group
=
process_group
,
device
=
device
,
async_ops
=
async_op
,
)
as
coalescing_manager
:
# Gather Float8BlockwiseQTensor data for row-wise usage
if
quantizer
.
rowwise_usage
:
# Launch all-gathers
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_rowwise_scale_inv
,
inp
.
_rowwise_scale_inv
,
group
=
process_group
,
)
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_rowwise_data
,
inp
.
_rowwise_data
,
group
=
process_group
,
)
# Gather Float8BlockwiseQTensor data for column-wise usage
if
quantizer
.
columnwise_usage
:
# Launch all-gathers
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_columnwise_scale_inv
,
inp
.
_columnwise_scale_inv
,
group
=
process_group
,
)
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_columnwise_data
,
inp
.
_columnwise_data
,
group
=
process_group
,
)
handle
=
coalescing_manager
if
async_op
else
None
# Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper
# This means that we need to transpose the gathered columnwise data
# Example usage is grad_output tensor, ie. dY in linear backward
# We want to gather two FP8 tensors (rowwise and columnwise) along dim0
# and then transpose the columnwise data to match the rowwise data
# Make sure FP8 transpose is populated if needed
if
async_op
:
handle
=
_FP8BlockwiseAllGatherAsyncHandle
(
out
,
quantizer
,
handle
)
else
:
# if it's a sync op, we need to do the transpose here as post processing step
_post_process_fp8_blockwise_gather
(
out
,
quantizer
,
handle
)
return
out
,
handle
def
_all_gather_mxfp8
(
...
...
@@ -1239,12 +1388,18 @@ def gather_along_first_dim(
final_quantizer
=
(
None
if
not
needs_quantized_gemm
(
inp
,
rowwise
=
True
)
else
quantizer
.
parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if
isinstance
(
rowwise
,
Float8BlockwiseQTensorBase
):
rowwise
=
inp
.
_original_tensor
rowwise_total
=
gather_along_first_dim
(
rowwise
,
process_group
,
False
,
final_quantizer
)[
0
]
out_obj
.
rowwise_gemm_tensor
=
rowwise_total
if
rowwise
is
not
columnwise
:
final_quantizer_columnwise
=
(
None
if
not
needs_quantized_gemm
(
inp
,
rowwise
=
False
)
else
quantizer
.
parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if
isinstance
(
columnwise
,
Float8BlockwiseQTensorBase
):
columnwise
=
inp
.
_original_tensor
columnwise_total
,
_
=
gather_along_first_dim
(
columnwise
,
process_group
,
False
,
final_quantizer_columnwise
)
...
...
@@ -1261,6 +1416,9 @@ def gather_along_first_dim(
)
if
isinstance
(
inp
,
QuantizedTensor
):
inp
=
inp
.
dequantize
()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
_set_quantizer_format
(
quantizer
,
compact
=
False
)
out
=
torch
.
empty
(
out_shape
,
dtype
=
inp
.
dtype
,
...
...
Prev
1
…
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment