OpenDAS / TransformerEngine - Commits

Commit 2b05e121
Authored Jun 17, 2025 by yuguo

    Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac

Changes: 245 files in total; this page shows 20 changed files with 295 additions and 81 deletions (+295, -81).
Changed files shown on this page:

  transformer_engine/pytorch/csrc/extensions/attention.cpp                        +11 -11
  transformer_engine/pytorch/csrc/extensions/cast.cpp                             +4 -1
  transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp                +8 -0
  transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp   +1 -1
  transformer_engine/pytorch/csrc/extensions/gemm.cpp                             +1 -1
  transformer_engine/pytorch/csrc/extensions/misc.cpp                             +1 -1
  transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp                +1 -1
  transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp       +1 -1
  transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp              +1 -1
  transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp               +1 -1
  transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp                 +1 -1
  transformer_engine/pytorch/csrc/extensions/normalization.cpp                    +7 -1
  transformer_engine/pytorch/csrc/extensions/padding.cpp                          +1 -1
  transformer_engine/pytorch/csrc/extensions/permutation.cpp                      +1 -1
  transformer_engine/pytorch/csrc/extensions/pybind.cpp                           +6 -3
  transformer_engine/pytorch/csrc/extensions/recipe.cpp                           +25 -20
  transformer_engine/pytorch/csrc/extensions/softmax.cpp                          +1 -1
  transformer_engine/pytorch/csrc/extensions/transpose.cpp                        +1 -1
  transformer_engine/pytorch/csrc/quantizer.cpp                                   +50 -19
  transformer_engine/pytorch/distributed.py                                       +172 -14
transformer_engine/pytorch/csrc/extensions/attention.cpp  (+11 -11)

@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include "../extensions.h"
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"

 namespace {

@@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
   NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");

-  size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype());
+  size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype());
   int32_t start_row = start_index.data_ptr<int32_t>()[0];
   void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) +
-                   static_cast<size_t>(start_row) * fcd_size * element_size;
+                   static_cast<size_t>(start_row) * fcd_size * element_size_bits / 8;
   size_t num_rows_to_zero = max_tokens - start_row;
-  size_t total_bytes = num_rows_to_zero * fcd_size * element_size;
+  size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8;
   NVTE_SCOPED_GIL_RELEASE(
       { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });

@@ -57,17 +57,17 @@ namespace transformer_engine::pytorch {
 // get the fused attention backend
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
 #ifdef __HIP_PLATFORM_AMD__
   return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
+      max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 #endif
 }
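The mha_fill change above switches the size arithmetic from whole bytes (typeToSize) to bits (typeToNumBits), which stays exact for sub-byte element types. The following is a small illustration of the arithmetic only, not TE code; the bit widths used are assumed example values.

# Illustration only (not TE code): the zero-fill size computed in bits, as in the updated mha_fill.
def fill_region_bytes(max_tokens: int, start_row: int, fcd_size: int, element_size_bits: int) -> int:
    num_rows_to_zero = max_tokens - start_row
    return num_rows_to_zero * fcd_size * element_size_bits // 8

# 512 total rows, fill starts at row 128, 1024 elements per row (bit widths are example values):
assert fill_region_bytes(512, 128, 1024, 16) == 786432   # 16-bit elements
assert fill_region_bytes(512, 128, 1024, 8) == 393216    # 8-bit elements
assert fill_region_bytes(512, 128, 1024, 4) == 196608    # a hypothetical 4-bit element type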
transformer_engine/pytorch/csrc/extensions/cast.cpp  (+4 -1)

@@ -6,8 +6,8 @@
 #include "transformer_engine/cast.h"

+#include "../extensions.h"
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"

@@ -81,6 +81,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp  (+8 -0)

@@ -216,6 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }

+at::Stream CommOverlap::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
+}
+
 /***************************************************************************************************
  * CommOverlapP2P
  **************************************************************************************************/

@@ -300,3 +304,7 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
   const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
+
+at::Stream CommOverlapP2P::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
+}
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/gemm.cpp  (+1 -1)

@@ -10,10 +10,10 @@
 #include <string>

 #include "../common.h"
+#include "../extensions.h"
 #include "common.h"
 #include "common/util/cuda_runtime.h"
 #include "common/util/system.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
 #include "util.h"
transformer_engine/pytorch/csrc/extensions/misc.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/normalization.cpp  (+7 -1)

@@ -4,8 +4,8 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include "../extensions.h"
 #include "common/util/system.h"
-#include "extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {

@@ -170,6 +170,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
       auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
       quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+      if (my_quantizer_bw->all_gather_usage) {
+        quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+      }
     }
     NVTE_SCOPED_GIL_RELEASE({
       nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,

@@ -328,6 +331,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
       auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
       quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+      if (my_quantizer_bw->all_gather_usage) {
+        quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+      }
     }
     NVTE_SCOPED_GIL_RELEASE({
       nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
transformer_engine/pytorch/csrc/extensions/padding.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/permutation.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/pybind.cpp  (+6 -3)

@@ -261,7 +261,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
   m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
-  m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
+  m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
+        py::call_guard<py::gil_scoped_release>());
 #ifdef USE_ROCM
   m.attr("_num_cublas_batchgemm_streams") = py::int_(transformer_engine::num_batchgemm_streams);
 #endif

@@ -390,7 +391,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
-           py::arg("shape") = std::nullopt);
+           py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlap::get_communication_stream);
   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(

@@ -407,5 +409,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
           py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
-          py::arg("shape") = std::nullopt);
+          py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
 }
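The new bindings replace the module-level "_num_cublas_streams" attribute with a callable and expose the communication stream of the overlap objects to Python. A hedged usage sketch follows; how user code imports the extension module here, and the variable "ub" standing for an already-constructed CommOverlap or CommOverlapP2P object, are assumptions for illustration.

# Sketch only; import name follows TE's own PyTorch code, not a documented public API.
import transformer_engine_torch as tex

num_streams = tex.get_num_cublas_streams()  # callable now, instead of the old _num_cublas_streams attribute
print("cuBLAS compute streams:", num_streams)

# For an existing overlap object "ub" (CommOverlap or CommOverlapP2P), the communication
# stream can now be queried from Python, e.g. to record events against it:
# comm_stream = ub.get_communication_stream()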
transformer_engine/pytorch/csrc/extensions/recipe.cpp  (+25 -20)

@@ -9,8 +9,8 @@
 #include <string>

-#include "common/common.h"
-#include "extensions.h"
+#include "../extensions.h"
+#include "transformer_engine/transformer_engine.h"

 namespace transformer_engine::pytorch {

@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                                  const std::string& amax_compute_algo,
                                                  DType fp8_dtype, float margin) {
   size_t num_tensors = amax_histories.size();
-  std::vector<Tensor> t_amax_histories(num_tensors);
-  std::vector<Tensor> t_scales(num_tensors);
-  std::vector<NVTETensor> te_amax_histories(num_tensors);
-  std::vector<NVTETensor> te_scales(num_tensors);
+  std::vector<NVTETensor> te_amax_histories;
+  std::vector<NVTETensor> te_scales;
+  te_amax_histories.reserve(num_tensors);
+  te_scales.reserve(num_tensors);
   for (size_t i = 0; i < num_tensors; i++) {
-    t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
-    auto amax_sizes = amax_histories[i].sizes().vec();
-    std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
-    t_amax_histories[i].data.shape = amax_shape;
-    t_amax_histories[i].data.dtype = DType::kFloat32;
-    t_scales[i].data.dptr = scales[i].data_ptr();
-    auto scale_sizes = scales[i].sizes().vec();
-    std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
-    t_scales[i].data.shape = scale_shape;
-    t_scales[i].data.dtype = DType::kFloat32;
-    te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
-    te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
+    te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
+    NVTETensor& amax_history = te_amax_histories.back();
+    NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes());
+    NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(),
+                                         static_cast<NVTEDType>(DType::kFloat32), amax_shape};
+    nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data);
+
+    te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
+    NVTETensor& scale = te_scales.back();
+    NVTEShape scale_shape = convertTorchShape(scales[i].sizes());
+    NVTEBasicTensor scale_data = {scales[i].data_ptr(),
+                                  static_cast<NVTEDType>(DType::kFloat32), scale_shape};
+    nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data);
   }
   nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
       makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales,
       amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
       at::cuda::getCurrentCUDAStream());
+  for (auto& t : te_amax_histories) {
+    nvte_destroy_tensor(t);
+  }
+  for (auto& t : te_scales) {
+    nvte_destroy_tensor(t);
+  }
 }

 }  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/softmax.cpp  (+1 -1)

@@ -4,7 +4,7 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "../extensions.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/extensions/transpose.cpp  (+1 -1)

@@ -6,7 +6,7 @@
 #include <optional>

-#include "extensions.h"
+#include "../extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {
transformer_engine/pytorch/csrc/quantizer.cpp  (+50 -19)

@@ -261,6 +261,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
   NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
              "Unsupported block scaling dim.");
+  this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }

 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {

@@ -299,6 +300,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   size_t m_dim = numel / k_dim;
   constexpr size_t kBlockLen = 128;
+  Float8BlockScaleTensorFormat data_format =
+      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
+                        : Float8BlockScaleTensorFormat::GEMM_READY);

   if (rowwise_usage) {
     if (rowwise_data.has_value()) {
       data_rowwise = std::move(*rowwise_data);

@@ -308,16 +313,26 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      // 1D scaling can be GEMM_READY or COMPACT
+      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
+      // the default rowwise scaling-factor shape already transposes the scaling factor, so it is GEMM_READY
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(m_dim, 4);
+      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
+      // if the rowwise format is compact, the scaling factor is not transposed
+      if (rowwise_compact) {
+        std::swap(sinv0, sinv1);
+      }
     } else {
-      NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor rowwise."
-                 "Expected 1 or 2. Got ", block_scaling_dim);
+      NVTE_ERROR("Unsupported block_scaling_dim in create_tensor rowwise. "
+                 "Expected 1 or 2. Got ", block_scaling_dim);
     }
     scale_inv_rowwise =
         at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);

@@ -332,28 +347,43 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
                columnwise_shape, " torch shape: ", torch_columnwise_shape);
     if (torch_shape.size() > 0) {
-      torch_columnwise_shape.reserve(torch_shape.size());
-      columnwise_shape.reserve(shape.size());
-      torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
-      columnwise_shape.push_back(shape[shape.size() - 1]);
-      for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
-        torch_columnwise_shape.push_back(torch_shape[i]);
-        columnwise_shape.push_back(shape[i]);
-      }
+      if (!all_gather_usage) {
+        torch_columnwise_shape.reserve(torch_shape.size());
+        columnwise_shape.reserve(shape.size());
+        torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
+        columnwise_shape.push_back(shape[shape.size() - 1]);
+        for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
+          torch_columnwise_shape.push_back(torch_shape[i]);
+          columnwise_shape.push_back(shape[i]);
+        }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
+      }
     }
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(k_dim, 4);
+      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
+      // GEMM_READY case: the scaling factor is [sinv0, sinv1], already transposed here for cuBLAS.
+      // For the COMPACT case, since we apply 128x1 scaling here without transposing the columnwise
+      // data, the scaling factor is also [sinv0, sinv1], so no need to swap sinv0 and sinv1 here.
     } else {
-      NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor columnwise."
-                 "Expected 1 or 2. Got ", block_scaling_dim);
+      NVTE_ERROR("Unsupported block_scaling_dim in create_tensor columnwise. "
+                 "Expected 1 or 2. Got ", block_scaling_dim);
     }
     data_colwise = at::empty(torch_columnwise_shape, opts);
     scale_inv_colwise =

@@ -373,7 +403,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
           "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
           "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
           "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-          "is_2D_scaled"_a = (block_scaling_dim == 2));
+          "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));

@@ -381,7 +411,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
         "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
         "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
-        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
+        "data_format"_a = data_format);
   }
   return {std::move(tensor), std::move(ret)};
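To make the scale-inverse shapes above concrete, here is a small sketch of the 1D rowwise sinv0/sinv1 arithmetic, assuming a 256x1024 input and 128-element blocks. It mirrors the logic in the hunk but is illustration only, not TE code.

# Sketch of the 1D rowwise scale-inv shape logic in Float8BlockQuantizer::create_tensor.
# Assumed example: m_dim = 256, k_dim = 1024, kBlockLen = 128.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def roundup(a: int, multiple: int) -> int:
    return ceil_div(a, multiple) * multiple

m_dim, k_dim, k_block_len = 256, 1024, 128

# GEMM_READY (default): scales are stored transposed and padded to a multiple of 4 for cuBLAS.
sinv0, sinv1 = ceil_div(k_dim, k_block_len), roundup(m_dim, 4)
assert (sinv0, sinv1) == (8, 256)

# COMPACT (all_gather_usage): no padding, and the two dims are swapped so the scales keep
# the same orientation as the data, which is what the all-gather path expects.
sinv0, sinv1 = ceil_div(k_dim, k_block_len), m_dim
sinv0, sinv1 = sinv1, sinv0
assert (sinv0, sinv1) == (256, 8)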
transformer_engine/pytorch/distributed.py  (+172 -14)

@@ -8,6 +8,7 @@ from __future__ import annotations
 from collections.abc import Iterable
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
 from functools import lru_cache
+from dataclasses import dataclass
 import math
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings

@@ -19,6 +20,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

+try:
+    import torch.distributed._symmetric_memory as symm_mem
+
+    HAS_TORCH_SYMMETRIC = True
+except ImportError:
+    HAS_TORCH_SYMMETRIC = False
+
+import transformer_engine_torch as tex
+
 from . import torch_version
 from .utils import (
     is_non_tn_fp8_gemm_supported,

@@ -34,14 +44,8 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
 from .tensor._internal.float8_tensor_base import Float8TensorBase
 from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
-from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
-
-try:
-    import torch.distributed._symmetric_memory as symm_mem
-
-    HAS_TORCH_SYMMETRIC = True
-except ImportError:
-    HAS_TORCH_SYMMETRIC = False
+from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer

 __all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -943,7 +947,7 @@ def _all_gather_fp8(
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
-        out._data = torch.empty_like(
+        out._data = torch.empty(
             out_shape,
             dtype=torch.uint8,
             device=inp.device,
@@ -977,6 +981,67 @@ def _all_gather_fp8(
     return out, handle


+def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
+    """Make quantizer compact"""
+    _quantizer = quantizer
+    if isinstance(quantizer, DebugQuantizer):
+        _quantizer = quantizer.parent_quantizer
+    if isinstance(_quantizer, Float8BlockQuantizer):
+        _quantizer.all_gather_usage = compact
+
+
+def _post_process_fp8_blockwise_gather(
+    out: Float8BlockwiseQTensorBase,
+    quantizer: Float8BlockQuantizer,
+    handle: Optional[torch.distributed.Work] = None,
+) -> Float8BlockwiseQTensorBase:
+    """Post-process FP8 blockwise gather."""
+    if handle is not None:
+        handle.wait()
+        handle = None
+
+    if out._is_gemm_ready_format():
+        return out
+
+    needs_columnwise_data_transpose = (
+        quantizer is not None
+        and quantizer.columnwise_usage
+        and not is_non_tn_fp8_gemm_supported()
+    )
+    need_rowwise_scale_transpose = (
+        quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
+    )
+    # cuBLAS requires a transpose of the scale-inv tensor. Suppose the original input is 256x1024:
+    # columnwise compact format means doing 128x1 quantization of it, so the quantized tensor is
+    # 256x1024 and the scale inv is 2x1024. If we were doing GEMM_READY format, it would be
+    # equivalent to 1x128 quantization of the transposed 1024x256 tensor, so the scale inv would
+    # be 1024x2, and cuBLAS requires 2x1024. Therefore we do not need to transpose the scale inv,
+    # only the columnwise data.
+    if needs_columnwise_data_transpose:
+        out._transpose_columnwise_data()
+    if need_rowwise_scale_transpose:
+        out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
+    out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
+    return out
+
+
+@dataclass
+class _FP8BlockwiseAllGatherAsyncHandle:
+    """Handle for asynchronous FP8 blockwise all-gather."""
+
+    tensor: Float8BlockwiseQTensorBase
+    quantizer: Float8BlockQuantizer
+    async_handle: torch.distributed.Work
+    _synchronized: bool = False
+
+    def wait(self) -> None:
+        """Wait for the async operation to complete and post-process the tensor."""
+        if self._synchronized:
+            return
+        self.async_handle.wait()
+        _post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
+        self._synchronized = True
+
+
 def _all_gather_fp8_blockwise(
     inp: torch.Tensor,
     process_group: dist_group_type,
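A hypothetical usage sketch of the asynchronous path introduced above; the keyword arguments passed to _all_gather_fp8_blockwise are assumptions for illustration, not a confirmed signature.

# Hypothetical sketch: driving the handle returned when async_op=True.
# Argument names other than inp/process_group are assumptions for illustration.
out, handle = _all_gather_fp8_blockwise(
    inp, process_group, async_op=True, quantizer=quantizer, out_shape=out_shape
)
# ... overlap independent work here ...
if handle is not None:
    handle.wait()  # waits on the coalesced NCCL work, then converts `out` to GEMM_READY format
# `out` is now safe to feed to a GEMM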
@@ -990,8 +1055,9 @@ def _all_gather_fp8_blockwise(
     Returns: quantizer(gather(inp))

-    NOTE: The implementation is not sophisticated enough to honor async_op=True.
-    In some cases it falls back to synchronous gather and invokes the quantizer.
+    NOTE: The implementation is only going to honor async_op=True for the FP8 gather case.
+    In the case where the tensor shape is not divisible by 128, the implementation will fall back
+    to synchronous gather and invoke the quantizer.
     """

     # Input tensor attributes

@@ -1027,7 +1093,11 @@ def _all_gather_fp8_blockwise(
     out_shape[0] *= world_size

     # Doing BF16 gather for now as baseline because it's simpler
-    if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
+    if (
+        not isinstance(inp, Float8BlockwiseQTensorBase)
+        and quantizer is not None
+        and not quantizer.is_quantizable(inp)
+    ):
         out = torch.empty(
             out_shape,
             dtype=dtype,
@@ -1035,14 +1105,93 @@ def _all_gather_fp8_blockwise(
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+        orig_all_gather_usage = quantizer.all_gather_usage
+        quantizer.all_gather_usage = False
         out = quantizer(out)
+        quantizer.all_gather_usage = orig_all_gather_usage
         return out, None

     # Implementation of fp8 gather needs to account for:
     # * Getting columnwise data as a transpose of how it is stored for GEMMS.
     # * Gathering non GEMM swizzled scales.
-    # * Refer to scaffold code when implementing at:
-    # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
-    raise NotImplementedError("fp8 blockwise allgather not yet implemented")
+
+    # Cast input tensor to Float8BlockwiseQTensor with required data
+    # Set to compact usage in case the quantizer is not correctly configured
+    orig_all_gather_usage = quantizer.all_gather_usage
+    quantizer.all_gather_usage = True
+    if not isinstance(inp, Float8BlockwiseQTensorBase):
+        inp = quantizer(inp)
+    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
+        quantizer.columnwise_usage and inp._columnwise_data is None
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to Float8BlockwiseQTensor."
+        )
+        inp = quantizer(inp.dequantize())
+    quantizer.all_gather_usage = orig_all_gather_usage
+
+    # Begin to do network communication, need to make sure compact format
+    if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
+        raise RuntimeError(
+            "All-gather with FP8 block-wise quantized tensor requires compact data format, "
+            f"but found data_format={inp._data_format}"
+        )
+
+    # Construct Float8BlockwiseQTensor output tensor
+    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+
+    # Coalesce NCCL collectives
+    with torch.distributed._coalescing_manager(
+        group=process_group,
+        device=device,
+        async_ops=async_op,
+    ) as coalescing_manager:
+
+        # Gather Float8BlockwiseQTensor data for row-wise usage
+        if quantizer.rowwise_usage:
+            # Launch all-gathers
+            torch.distributed.all_gather_into_tensor(
+                out._rowwise_scale_inv,
+                inp._rowwise_scale_inv,
+                group=process_group,
+            )
+            torch.distributed.all_gather_into_tensor(
+                out._rowwise_data,
+                inp._rowwise_data,
+                group=process_group,
+            )
+
+        # Gather Float8BlockwiseQTensor data for column-wise usage
+        if quantizer.columnwise_usage:
+            # Launch all-gathers
+            torch.distributed.all_gather_into_tensor(
+                out._columnwise_scale_inv,
+                inp._columnwise_scale_inv,
+                group=process_group,
+            )
+            torch.distributed.all_gather_into_tensor(
+                out._columnwise_data,
+                inp._columnwise_data,
+                group=process_group,
+            )
+
+    handle = coalescing_manager if async_op else None
+
+    # Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper.
+    # This means that we need to transpose the gathered columnwise data.
+    # Example usage is the grad_output tensor, i.e. dY in linear backward:
+    # we want to gather two FP8 tensors (rowwise and columnwise) along dim0
+    # and then transpose the columnwise data to match the rowwise data.
+    # Make sure FP8 transpose is populated if needed.
+    if async_op:
+        handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
+    else:
+        # if it's a sync op, we need to do the transpose here as a post-processing step
+        _post_process_fp8_blockwise_gather(out, quantizer, handle)
+    return out, handle


 def _all_gather_mxfp8(
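The gather above issues up to four all-gathers (rowwise/columnwise data and scales) inside torch.distributed._coalescing_manager so that they go out as a single batched NCCL call. Below is a minimal self-contained sketch of that pattern, assuming an already-initialized NCCL process group; it is illustration only, not TE code.

# Minimal sketch of coalesced all-gathers (assumes torch.distributed is initialized with NCCL).
import torch
import torch.distributed as dist

def coalesced_all_gather(tensors, group, device, async_op=False):
    """All-gather every tensor along dim 0 inside one coalescing context."""
    world_size = dist.get_world_size(group)
    outs = [
        torch.empty((world_size * t.shape[0], *t.shape[1:]), dtype=t.dtype, device=device)
        for t in tensors
    ]
    with dist._coalescing_manager(group=group, device=device, async_ops=async_op) as cm:
        for out, t in zip(outs, tensors):
            dist.all_gather_into_tensor(out, t, group=group)
    # When async_op is True, cm.wait() completes the batched collectives.
    return outs, (cm if async_op else None)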
@@ -1239,12 +1388,18 @@ def gather_along_first_dim(
         final_quantizer = (
             None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
         )
+        # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+        if isinstance(rowwise, Float8BlockwiseQTensorBase):
+            rowwise = inp._original_tensor
         rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
         out_obj.rowwise_gemm_tensor = rowwise_total
         if rowwise is not columnwise:
             final_quantizer_columnwise = (
                 None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
             )
+            # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+            if isinstance(columnwise, Float8BlockwiseQTensorBase):
+                columnwise = inp._original_tensor
             columnwise_total, _ = gather_along_first_dim(
                 columnwise, process_group, False, final_quantizer_columnwise
             )
@@ -1261,6 +1416,9 @@ def gather_along_first_dim(
     )
     if isinstance(inp, QuantizedTensor):
         inp = inp.dequantize()
+    # Falling back to high-precision all-gather for Float8BlockQuantizer
+    # means that it should directly output GEMM_READY format
+    _set_quantizer_format(quantizer, compact=False)
     out = torch.empty(
         out_shape,
         dtype=inp.dtype,