OpenDAS / TransformerEngine / Commits

Commit 44740c6c, authored Jul 22, 2025 by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine

Parents: 8113d9e0, 7a9a0825

Changes: 162 files
Showing 20 changed files with 1613 additions and 332 deletions (+1613 / -332).
transformer_engine/pytorch/attention/inference.py        (+7, -1)
transformer_engine/pytorch/cpu_offload.py                (+4, -1)
transformer_engine/pytorch/csrc/common.h                 (+5, -0)
transformer_engine/pytorch/csrc/extensions.h             (+43, -5)
transformer_engine/pytorch/csrc/extensions/cast.cpp      (+597, -59)
transformer_engine/pytorch/csrc/extensions/padding.cpp   (+73, -0)
transformer_engine/pytorch/csrc/extensions/pybind.cpp    (+36, -5)
transformer_engine/pytorch/csrc/extensions/router.cpp    (+190, -0)
transformer_engine/pytorch/csrc/extensions/transpose.cpp (+7, -70)
transformer_engine/pytorch/csrc/quantizer.cpp            (+121, -62)
transformer_engine/pytorch/export.py                     (+71, -0)
transformer_engine/pytorch/fp8.py                        (+7, -1)
transformer_engine/pytorch/graph.py                      (+202, -35)
transformer_engine/pytorch/jit.py                        (+12, -2)
transformer_engine/pytorch/module/_common.py             (+3, -0)
transformer_engine/pytorch/module/base.py                (+1, -0)
transformer_engine/pytorch/module/fp8_padding.py         (+18, -17)
transformer_engine/pytorch/module/fp8_unpadding.py       (+16, -13)
transformer_engine/pytorch/module/grouped_linear.py      (+120, -52)
transformer_engine/pytorch/module/layernorm_linear.py    (+80, -9)
transformer_engine/pytorch/attention/inference.py

@@ -420,6 +420,8 @@ class NonPagedKVCacheManager(KVCacheManager):
             dtype=torch.int32,
             device=torch.cuda.current_device(),
         )
+        # whether reindexing is needed, i.e. when batch seq_ids have changed
+        self.need_reindex = True

     def allocate_memory(self, layer_number):
         """Allocate memory for the cache"""
...
@@ -451,6 +453,7 @@ class NonPagedKVCacheManager(KVCacheManager):
         # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that
         # they are contiguous and match the indexing in q
         prev_batch_size = len(self.sequences)
+        prev_seq_ids = set(self.sequences.keys())
         unfinished_seqs = self.sequences.keys() & step_dict.keys()
         finished_seqs = self.sequences.keys() - unfinished_seqs
         unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
...
@@ -478,6 +481,9 @@ class NonPagedKVCacheManager(KVCacheManager):
         for i in new_seqs:
             self.sequences[i] = step_dict[i]

+        # Whether reindexing is needed
+        self.need_reindex = set(self.sequences.keys()) != prev_seq_ids
+
         return self.sequences

     def step(
...
@@ -538,7 +544,7 @@ class NonPagedKVCacheManager(KVCacheManager):
             ctx_len,
             self.max_seqlen,
             1,
-            True,
+            self.need_reindex,
         )

         k_cache = k_cache[:batch_size]
...
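The change above gates KV-cache re-indexing on whether the set of sequence ids actually changed between steps. A minimal sketch of that condition, using made-up ids purely for illustration (not TE code):

    # Illustration only: reindexing is skipped when the batch's sequence ids
    # are unchanged from the previous step.
    prev_seq_ids = {0, 2, 3}
    current_seq_ids = {0, 2, 3}          # same sequences as last step
    need_reindex = current_seq_ids != prev_seq_ids
    assert need_reindex is False         # k_cache/v_cache re-indexing can be skipped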
transformer_engine/pytorch/cpu_offload.py

@@ -9,8 +9,8 @@ from typing import Any, Dict, Optional

 import torch

+from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from .tensor.quantized_tensor import QuantizedTensorBase
 from .tensor.float8_tensor import Float8Tensor

 __all__ = ["get_cpu_offload_context"]
...
@@ -20,6 +20,9 @@ CPUOffloadEnabled = False

 def mark_activation_offload(*tensors):
     """Set the type of the offloading needed for a tensor."""
+    if TEDebugState.debug_enabled:
+        raise RuntimeError("CPU offload is not supported in debug mode.")
     for tensor in tensors:
         if tensor is None:
             continue
...
transformer_engine/pytorch/csrc/common.h

@@ -33,6 +33,7 @@
 #include <transformer_engine/comm_gemm_overlap.h>
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/fused_rope.h>
+#include <transformer_engine/fused_router.h>
 #include <transformer_engine/gemm.h>
 #include <transformer_engine/multi_stream.h>
 #include <transformer_engine/multi_tensor.h>
...
@@ -215,6 +216,8 @@ class Float8BlockQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(
       const std::vector<size_t>& shape, DType dtype,
       std::optional<at::Tensor> rowwise_data = std::nullopt) const override;

+  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
+
 };

 class MXFP8Quantizer : public Quantizer {
...
@@ -230,6 +233,8 @@ class MXFP8Quantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(
       const std::vector<size_t>& shape, DType dtype,
       std::optional<at::Tensor> rowwise_data = std::nullopt) const override;

+  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
+
 };

 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
...
transformer_engine/pytorch/csrc/extensions.h

@@ -13,6 +13,38 @@

 namespace transformer_engine::pytorch {

+/***************************************************************************************************
+ * Router fusion
+ **************************************************************************************************/
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
+    at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
+    c10::optional<int> group_topk, c10::optional<float> scaling_factor,
+    std::string score_function, c10::optional<at::Tensor> expert_bias);
+
+at::Tensor fused_topk_with_score_function_bwd(
+    int num_tokens, int num_experts, at::Tensor routing_map, at::Tensor intermediate_output,
+    at::Tensor grad_probs, int topk, bool use_pre_softmax, c10::optional<float> scaling_factor,
+    std::string score_function);
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
+    at::Tensor logits, int topk, std::string score_function);
+
+at::Tensor fused_score_for_moe_aux_loss_bwd(
+    int num_tokens, int num_experts, at::Tensor intermediate_output, at::Tensor grad_probs,
+    int topk, std::string score_function);
+
+std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(
+    at::Tensor probs, at::Tensor tokens_per_expert, int total_num_tokens, int num_experts,
+    int num_rows, int num_cols, int topk, float coeff);
+
+at::Tensor fused_moe_aux_loss_bwd(
+    at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, int num_cols,
+    at::Tensor grad_aux_loss);
+
 /***************************************************************************************************
  * Permutation
  **************************************************************************************************/
...
@@ -136,10 +168,6 @@ std::vector<at::Tensor> te_batchgemm_ts(
  * Transpose
  **************************************************************************************************/

-std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
-                                             std::optional<std::vector<py::object>> output_list,
-                                             std::vector<py::handle> quantizer_list, DType otype);
-
 at::Tensor fp8_transpose(at::Tensor input, DType otype,
                          std::optional<at::Tensor> output = std::nullopt);
...
@@ -210,10 +238,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
 **************************************************************************************************/

 py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output,
-                    std::optional<at::Tensor> noop);
+                    std::optional<at::Tensor> noop_flag);

 py::object dequantize(const py::handle& input, DType otype);

+std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor>& tensor_list,
+                                              std::vector<py::handle> quantizer_list);
+
+std::vector<py::object> split_quantize(const at::Tensor& tensor,
+                                       const std::vector<int>& split_sections,
+                                       std::vector<py::handle> quantizer_list);
+
 /***************************************************************************************************
  * Bias gradient fusions
 **************************************************************************************************/
...
@@ -395,6 +430,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list);

+void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
+                               std::vector<size_t> input_row_list,
+                               std::vector<size_t> unpadded_input_row_list);
+
 /***************************************************************************************************
  * NVSHMEM APIs
 **************************************************************************************************/
...
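The two new quantize entry points declared here (multi_tensor_quantize and split_quantize) are bound to Python in pybind.cpp below. A hypothetical call-shape sketch; the compiled-extension import name and the make_quantizer() helper are assumptions for illustration only, not part of this commit:

    # Illustration only: quantize several tensors with one call, or split one
    # tensor along dim 0 and quantize each chunk.
    import torch
    import transformer_engine_torch as tex     # assumed import name of the compiled extension

    quantizers = [make_quantizer() for _ in range(3)]   # hypothetical helper returning TE quantizers
    tensors = [torch.randn(128, 64, device="cuda") for _ in range(3)]
    quantized = tex.multi_tensor_quantize(tensors, quantizers)

    big = torch.randn(96, 64, device="cuda")
    quantized_splits = tex.split_quantize(big, [32, 32, 32], quantizers)  # split along dim 0, then quantize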
transformer_engine/pytorch/csrc/extensions/cast.cpp

@@ -6,60 +6,51 @@
 #include "transformer_engine/cast.h"

+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <vector>
+
 #include "../extensions.h"
 #include "common.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"

-namespace transformer_engine::pytorch {
-
-py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output,
-                    std::optional<at::Tensor> noop) {
-  init_extension();
-  auto my_quantizer = convert_quantizer(quantizer);
-  auto input_tensor = tensor.contiguous();
-  const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
-  const auto& te_input_shape = te_input.shape();
-  std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
-  auto fake_tensor_type = tensor.scalar_type();
-  if (!detail::IsFloatingPointType(fake_tensor_type)) {
-    fake_tensor_type = at::kFloat;
-  }
-  TensorWrapper te_output;
-  py::object out;
-  if (output.is_none()) {
-    DType fake_te_type = GetTransformerEngineDType(fake_tensor_type);
-    std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type);
-  } else {
-    out = output;
-    te_output = makeTransformerEngineTensor(output, quantizer);
-  }
-  TensorWrapper te_noop;
-  if (noop.has_value()) {
-    te_noop = makeTransformerEngineTensor(*noop);
-  } else {
-    te_noop = TensorWrapper();
-  }
-  if (te_output.numel() == 0) return out;
-  QuantizationConfigWrapper quant_config;
-  quant_config.set_noop_tensor(te_noop.data());
-  if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
-    // my_quantizer here has to be a Float8CurrentScalingQuantizer
-    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
-    NVTE_SCOPED_GIL_RELEASE(
-        { nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
+namespace transformer_engine {
+namespace pytorch {
+
+namespace {
+
+std::vector<size_t> get_tensor_shape(const TensorWrapper& tensor) {
+  const auto& shape = tensor.shape();
+  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
+}
+
+void quantize_impl(const TensorWrapper& input, py::handle& quantizer_py,
+                   std::unique_ptr<Quantizer>& quantizer_cpp, TensorWrapper& output,
+                   TensorWrapper& noop_flag) {
+  // Check tensor dims
+  NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
+             "Input tensor (shape=", get_tensor_shape(input), ") and output tensor (shape=",
+             get_tensor_shape(output), ") do not match");
+  if (input.numel() == 0) {
+    return;
+  }
+
+  // Recipe-specific configuration
+  QuantizationConfigWrapper quant_config;
+  quant_config.set_noop_tensor(noop_flag.data());
+  if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
+    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream());
+    });
     // check if we need to do amax reudction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
       c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
       // construct torch tesnor from NVTEBasicTensor without reallocating memory
       at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
       std::vector<at::Tensor> tensors = {amax_tensor_torch};
       // allreduce amax tensor
       c10d::AllreduceOptions allreduce_opts;
...
@@ -72,37 +63,70 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
       quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
       NVTE_SCOPED_GIL_RELEASE({
-        nvte_compute_scale_from_amax(te_output.data(), quant_config,
+        nvte_compute_scale_from_amax(output.data(), quant_config,
                                      at::cuda::getCurrentCUDAStream());
       });
-    // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
-    te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
-  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
-    auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
+    // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
+    output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
+    auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(quantizer_cpp.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
     if (my_quantizer_bw->all_gather_usage) {
       quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
     }
   }
+
+  // Perform quantization
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
+    nvte_quantize_v2(input.data(), output.data(), quant_config,
                      at::cuda::getCurrentCUDAStream());
   });
-  return out;
-}
+}
+
+}  // namespace
+
+py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output,
+                    std::optional<at::Tensor> noop_flag) {
+  // Convert quantizer to C++ object
+  auto quantizer_cpp = convert_quantizer(quantizer);
+
+  // Convert input tensor to C++ object
+  auto input_contiguous = tensor.contiguous();
+  const auto input_cpp = makeTransformerEngineTensor(input_contiguous);
+
+  // Initialize output tensor
+  TensorWrapper output_cpp;
+  py::object output_py;
+  if (output.is_none()) {
+    const auto shape = get_tensor_shape(input_cpp);
+    const auto fake_dtype = input_cpp.dtype();
+    std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
+  } else {
+    output_py = output;
+    output_cpp = makeTransformerEngineTensor(output_py, quantizer);
+  }
+
+  // Initialize no-op flag
+  TensorWrapper noop_flag_cpp;
+  if (noop_flag.has_value()) {
+    noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
+  }
+
+  // Perform quantization
+  quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
+
+  return output_py;
+}

 py::object dequantize(const py::handle& input, transformer_engine::DType otype) {
   init_extension();
   const auto none = py::none();
   const auto& input_tensor = makeTransformerEngineTensor(input, none);
   NoneQuantizer q(none);
   const auto& shape = convertShape(input_tensor.shape());
   auto [out_tensor, out] = q.create_tensor(shape, otype);
...
@@ -113,9 +137,522 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)
   return out;
 }

+namespace {
+
+void multi_tensor_quantize_impl(const std::vector<TensorWrapper>& input_list,
+                                std::vector<py::handle>& quantizer_py_list,
+                                std::vector<std::unique_ptr<Quantizer>>& quantizer_cpp_list,
+                                std::vector<TensorWrapper>& output_list) {
+  // Check number of tensors
+  const size_t num_tensors = input_list.size();
+  NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors,
+             " Python quantizers, but got ", quantizer_py_list.size());
+  NVTE_CHECK(quantizer_cpp_list.size() == num_tensors, "Expected ", num_tensors,
+             " C++ quantizers, but got ", quantizer_cpp_list.size());
+  NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors,
+             " output tensors, but got ", output_list.size());
+
+  // Choose implementation
+  // Note: Currently only have fused kernel for FP8 delayed scaling
+  bool with_fused_kernel = true;
+  for (size_t i = 0; i < num_tensors; i++) {
+    if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) {
+      with_fused_kernel = false;
+      break;
+    }
+    if (nvte_tensor_data(output_list[i].data()) == nullptr ||
+        nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) {
+      with_fused_kernel = false;
+      break;
+    }
+  }
+
+  // Launch TE kernel
+  if (with_fused_kernel) {
+    // Fused kernel for multi-tensor quantize
+    std::vector<NVTETensor> nvte_tensor_input_list;
+    std::vector<NVTETensor> nvte_tensor_output_list;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      nvte_tensor_input_list.push_back(input_list[i].data());
+      nvte_tensor_output_list.push_back(output_list[i].data());
+    }
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
+                                nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
+    });
+  } else {
+    // Quantize kernels individually
+    TensorWrapper dummy_noop_flag;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
+                    dummy_noop_flag);
+    }
+  }
+}
+
+}  // namespace
+
+std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor>& tensor_list,
+                                              std::vector<py::handle> quantizer_list) {
+  // Check number of tensors
+  const size_t num_tensors = tensor_list.size();
+  NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors,
+             " quantizers, but got ", quantizer_list.size());
+
+  // Convert quantizers to C++ objects
+  std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
+  for (size_t i = 0; i < num_tensors; i++) {
+    quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
+  }
+
+  // Initialize input and output tensors
+  std::vector<TensorWrapper> input_cpp_list;
+  std::vector<TensorWrapper> output_cpp_list;
+  std::vector<py::object> output_py_list;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    // Convert input tensor to C++ object
+    const auto& input_py = tensor_list[i];
+    NVTE_CHECK(input_py.is_contiguous(), "Input tensor ", i, " is not contiguous");
+    input_cpp_list.emplace_back(makeTransformerEngineTensor(input_py));
+    const auto& input_cpp = input_cpp_list.back();
+    const auto input_shape = input_cpp.shape();
+    const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());
+
+    // Construct output tensor
+    std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
+    auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype);
+    output_cpp_list.emplace_back(std::move(output_cpp));
+    output_py_list.emplace_back(std::move(output_py));
+  }
+
+  // Perform multi-tensor quantization
+  multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
+
+  return output_py_list;
+}
+
+namespace {
+
+std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp8_blockwise_tensors(
+    std::vector<std::vector<size_t>>& shape_list, std::vector<py::handle>& quantizer_py_list,
+    std::vector<Float8BlockQuantizer*>& quantizer_cpp_list) {
+  init_extension();
+  std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
+  auto& tensor_py_list = std::get<0>(retval);
+  auto& tensor_cpp_list = std::get<1>(retval);
+
+  // Number of tensors
+  const size_t num_tensors = shape_list.size();
+  if (num_tensors == 0) {
+    return retval;
+  }
+
+  // Quantization parameters
+  const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
+  const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
+  const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
+  const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D;
+  const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
+  constexpr size_t fp8_elem_size = 1;
+  constexpr size_t scale_elem_size = 4;
+
+  // Helper function to construct tensor view
+  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
+  // will survive until all views are deleted.
+  auto make_torch_view = [](std::shared_ptr<at::Tensor>& buffer, const std::vector<size_t>& shape,
+                            size_t offset, at::ScalarType dtype) -> at::Tensor {
+    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
+    // in the case where full buffer is empty because local rank receives no tokens for all the experts
+    // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
+    // but in the case where some experts receive tokens, some not, we want to leverage from_blob
+    // as much as possible to avoid CPU overhead
+    if (buffer->data_ptr<uint8_t>() == nullptr) {
+      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
+    }
+    return at::from_blob(
+        buffer->data_ptr<uint8_t>() + offset, shape_int64,
+        [buffer](void*) {},  // deleter holds shared_ptr
+        at::device(at::kCUDA).dtype(dtype));
+  };
+
+  // Allocate row-wise data
+  std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
+  std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
+  if (rowwise_usage) {
+    // Tensor sizes
+    for (size_t i = 0; i < num_tensors; ++i) {
+      rowwise_data_shapes.emplace_back(shape_list[i]);
+      rowwise_scale_shapes.emplace_back(
+          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
+    }
+
+    // Offsets in full buffer
+    size_t buffer_size = 0;
+    std::vector<size_t> data_offsets, scale_offsets;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 256);  // align to 256B
+      data_offsets.push_back(buffer_size);
+      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      scale_offsets.push_back(buffer_size);
+      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
+    }
+
+    // Allocate full buffer
+    auto buffer = std::make_shared<at::Tensor>(
+        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+
+    // Construct tensor views
+    for (size_t i = 0; i < num_tensors; ++i) {
+      rowwise_data_list.emplace_back(
+          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
+      rowwise_scale_list.emplace_back(
+          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
+    }
+  }
+
+  // Allocate column-wise data
+  std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
+  std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
+  if (columnwise_usage) {
+    // Tensor sizes
+    for (size_t i = 0; i < num_tensors; ++i) {
+      columnwise_data_shapes.emplace_back();
+      auto& shape = columnwise_data_shapes.back();
+      shape.push_back(shape_list[i].back());
+      for (size_t j = 0; j < shape_list[i].size() - 1; ++j) {
+        shape.push_back(shape_list[i][j]);
+      }
+      columnwise_scale_shapes.emplace_back(
+          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
+    }
+
+    // Offsets in full buffer
+    size_t buffer_size = 0;
+    std::vector<size_t> data_offsets, scale_offsets;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 256);  // align to 256B
+      data_offsets.push_back(buffer_size);
+      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      scale_offsets.push_back(buffer_size);
+      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
+    }
+
+    // Allocate full buffer
+    auto buffer = std::make_shared<at::Tensor>(
+        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+
+    // Construct tensor views
+    for (size_t i = 0; i < num_tensors; ++i) {
+      columnwise_data_list.emplace_back(
+          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
+      columnwise_scale_list.emplace_back(
+          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
+    }
+  }
+
+  // Construct FP8 block-wise tensors
+  py::handle Float8BlockwiseQTensorClass(
+      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
+  for (size_t i = 0; i < num_tensors; ++i) {
+    // Create tensor objects with proper reference counting
+    py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
+    py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
+    py::object columnwise_data =
+        (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
+    py::object columnwise_scale =
+        (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
+
+    // Construct Python tensor
+    tensor_py_list.emplace_back(Float8BlockwiseQTensorClass(
+        rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype,
+        quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY));
+
+    // Construct C++ tensor
+    tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
+        rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
+        columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
+        rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
+        columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
+        nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
+        columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
+        rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
+        columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
+  }
+
+  return retval;
+}
+
+std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mxfp8_tensors(
+    std::vector<std::vector<size_t>>& shape_list, std::vector<py::handle>& quantizer_py_list,
+    std::vector<MXFP8Quantizer*>& quantizer_cpp_list) {
+  init_extension();
+  std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
+  auto& tensor_py_list = std::get<0>(retval);
+  auto& tensor_cpp_list = std::get<1>(retval);
+
+  // Number of tensors
+  const size_t num_tensors = shape_list.size();
+  if (num_tensors == 0) {
+    return retval;
+  }
+
+  // Quantization parameters
+  const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
+  const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
+  const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
+  const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
+  constexpr size_t fp8_elem_size = 1;
+  constexpr size_t scale_elem_size = 1;
+
+  // Helper function to construct tensor view
+  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
+  // will survive until all views are deleted.
+  auto make_torch_view = [](std::shared_ptr<at::Tensor>& buffer, const std::vector<size_t>& shape,
+                            size_t offset, at::ScalarType dtype) -> at::Tensor {
+    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
+    // in the case where full buffer is empty because local rank receives no tokens for all the experts
+    // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
+    // but in the case where some experts receive tokens, some not, we want to leverage from_blob
+    // as much as possible to avoid CPU overhead
+    if (buffer->data_ptr<uint8_t>() == nullptr) {
+      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
+    }
+    return at::from_blob(
+        buffer->data_ptr<uint8_t>() + offset, shape_int64,
+        [buffer](void*) {},  // deleter holds shared_ptr
+        at::device(at::kCUDA).dtype(dtype));
+  };
+
+  // Allocate row-wise data
+  std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
+  std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
+  if (rowwise_usage) {
+    // Tensor sizes
+    for (size_t i = 0; i < num_tensors; ++i) {
+      rowwise_data_shapes.emplace_back(shape_list[i]);
+      rowwise_scale_shapes.emplace_back(
+          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
+    }
+
+    // Offsets in full buffer
+    size_t buffer_size = 0;
+    std::vector<size_t> data_offsets, scale_offsets;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 256);  // align to 256B
+      data_offsets.push_back(buffer_size);
+      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      scale_offsets.push_back(buffer_size);
+      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
+    }
+
+    // Allocate full buffer
+    // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
+    auto buffer = std::make_shared<at::Tensor>(
+        at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+    // auto buffer = std::make_shared<at::Tensor>(
+    //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+
+    // Construct tensor views
+    for (size_t i = 0; i < num_tensors; ++i) {
+      rowwise_data_list.emplace_back(
+          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
+      rowwise_scale_list.emplace_back(
+          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+    }
+  }
+
+  // Allocate column-wise data
+  std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
+  std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
+  if (columnwise_usage) {
+    // Tensor sizes
+    for (size_t i = 0; i < num_tensors; ++i) {
+      // For MXFP8, the columnwise data doesn't need transpose
+      // because of TN, NT, NN layout support in SM100
+      columnwise_data_shapes.emplace_back(shape_list[i]);
+      columnwise_scale_shapes.emplace_back(
+          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
+    }
+
+    // Offsets in full buffer
+    size_t buffer_size = 0;
+    std::vector<size_t> data_offsets, scale_offsets;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 256);  // align to 256B
+      data_offsets.push_back(buffer_size);
+      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      scale_offsets.push_back(buffer_size);
+      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
+    }
+
+    // Allocate full buffer
+    // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
+    auto buffer = std::make_shared<at::Tensor>(
+        at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+    // auto buffer = std::make_shared<at::Tensor>(
+    //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+
+    // Construct tensor views
+    for (size_t i = 0; i < num_tensors; ++i) {
+      columnwise_data_list.emplace_back(
+          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
+      columnwise_scale_list.emplace_back(
+          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+    }
+  }
+
+  // Construct mxfp8 tensors
+  py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
+  for (size_t i = 0; i < num_tensors; ++i) {
+    // Create tensor objects with proper reference counting
+    py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
+    py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
+    py::object columnwise_data =
+        (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
+    py::object columnwise_scale =
+        (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
+
+    // Construct Python tensor
+    tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
+                                                 columnwise_scale, fp8_dtype,
+                                                 quantizer_py_list[i]));
+
+    // Construct C++ tensor
+    tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
+        rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
+        columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
+        rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
+        columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
+        nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
+        columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
+        rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
+        columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
+  }
+
+  return retval;
+}
+
+}  // namespace
+
+std::vector<py::object> split_quantize(const at::Tensor& tensor,
+                                       const std::vector<int>& split_sections,
+                                       std::vector<py::handle> quantizer_list) {
+  init_extension();
+
+  // Check number of tensors
+  const size_t num_splits = split_sections.size();
+  NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ",
+             quantizer_list.size());
+  if (num_splits == 0) {
+    return {};
+  }
+
+  // Input tensor properties
+  auto input_py = tensor.contiguous();
+  uint8_t* input_dptr = reinterpret_cast<uint8_t*>(input_py.data_ptr());
+  auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());
+  std::vector<size_t> input_shape;
+  size_t input_size = 1;
+  for (const auto& d : input_py.sizes()) {
+    input_shape.push_back(d);
+    input_size *= d;
+  }
+  NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims");
+
+  // Split input tensor along dim 0
+  std::vector<TensorWrapper> input_list;
+  std::vector<std::vector<size_t>> split_shapes;
+  size_t dim0_offset = 0;
+  const size_t dim0_stride =
+      input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0];
+  for (size_t i = 0; i < num_splits; ++i) {
+    NVTE_CHECK(split_sections[i] >= 0, "Attempted to split tensor with shape=", input_shape,
+               " along dim 0 with split_sections=", split_sections);
+    NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0],
+               "Attempted to split tensor with shape=", input_shape,
+               " along dim 0 with split_sections=", split_sections);
+    split_shapes.push_back(input_shape);
+    auto& split_shape = split_shapes.back();
+    split_shape[0] = split_sections[i];
+    void* split_dptr = static_cast<void*>(input_dptr + dim0_offset * dim0_stride);
+    input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype));
+    dim0_offset += split_sections[i];
+  }
+
+  // Convert quantizers to C++ objects
+  std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
+  for (size_t i = 0; i < num_splits; i++) {
+    quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
+  }
+
+  // For FP8 block-scaling, we construct output tensors with bulk allocations
+  // For MXFP8, we also use bulk allocations
+  bool use_fused_bulk_alloc = true;
+  for (size_t i = 0; i < quantizer_list.size(); i++) {
+    if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
+        !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) {
+      use_fused_bulk_alloc = false;
+      break;
+    }
+  }
+
+  // Allocate output tensors
+  std::vector<TensorWrapper> output_cpp_list;
+  std::vector<py::object> output_py_list;
+  if (!use_fused_bulk_alloc) {
+    // Allocate output tensors individually
+    for (size_t i = 0; i < num_splits; ++i) {
+      auto [output_cpp, output_py] =
+          quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype);
+      output_cpp_list.emplace_back(std::move(output_cpp));
+      output_py_list.emplace_back(std::move(output_py));
+    }
+  } else {
+    // TODO(zhongbo): make a better api to make this part less hacky
+    bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
+    bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
+    if (is_fp8_blockwise) {
+      // FP8 block-scaling: construct output tensors with bulk allocations
+      std::vector<Float8BlockQuantizer*> blockwise_quantizers;
+      for (auto& quantizer : quantizer_cpp_list) {
+        blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer*>(quantizer.get()));
+      }
+      std::tie(output_py_list, output_cpp_list) =
+          bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
+    } else if (is_mxfp8) {
+      // MXFP8: construct output tensors with bulk allocations
+      std::vector<MXFP8Quantizer*> mxfp8_quantizers;
+      for (auto& quantizer : quantizer_cpp_list) {
+        mxfp8_quantizers.push_back(static_cast<MXFP8Quantizer*>(quantizer.get()));
+      }
+      std::tie(output_py_list, output_cpp_list) =
+          bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
+    } else {
+      NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
+    }
+  }
+
+  // Perform multi-tensor quantization
+  multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
+
+  return output_py_list;
+}
+
 template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
                        cudaStream_t)>
 std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input,
                                    py::handle quantizer) {
   init_extension();
   auto my_quantizer = convert_quantizer(quantizer);
...
@@ -125,7 +662,7 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
   auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
   auto act_input_tensor = makeTransformerEngineTensor(act_input);

   const auto& shape = convertShape(grad_tensor.shape());
   auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
   auto dbias_tensor = makeTransformerEngineTensor(grad_bias);
...
@@ -149,29 +686,30 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
   return {py::cast(grad_bias), dact};
 }

 std::vector<py::object> dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input,
                                     py::handle quantizer) {
   return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
 }

 std::vector<py::object> dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input,
                                     py::handle quantizer) {
   return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
 }

 std::vector<py::object> dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input,
                                     py::handle quantizer) {
   return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
 }

 std::vector<py::object> dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input,
                                      py::handle quantizer) {
   return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
 }

 std::vector<py::object> dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input,
                                      py::handle quantizer) {
   return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
 }

-}  // namespace transformer_engine::pytorch
+}  // namespace pytorch
+}  // namespace transformer_engine
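The bulk allocators above pack all per-tensor data and scale regions into one uint8 buffer, aligning each data region to 256 bytes and each scale region to 16 bytes before recording its offset. A minimal sketch of that offset arithmetic (illustration only, not TE code):

    def roundup(x, align):
        # round x up to the next multiple of `align`
        return (x + align - 1) // align * align

    sizes = [1000, 2048, 12]          # bytes needed by three hypothetical tensors
    offsets, total = [], 0
    for s in sizes:
        total = roundup(total, 256)   # data regions are 256-byte aligned
        offsets.append(total)
        total += s
    print(offsets)                    # [0, 1024, 3072]
    print(total)                      # 3084 -> running size of the shared buffer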
transformer_engine/pytorch/csrc/extensions/padding.cpp

@@ -81,4 +81,77 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   });
 }

+void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
+                               std::vector<size_t> input_row_list,
+                               std::vector<size_t> unpadded_input_row_list) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  NVTE_CHECK(input_row_list.size() == unpadded_input_row_list.size(),
+             "Number of input row list and padded row list must match.");
+  NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
+  NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
+
+  const auto num_tensors = input_row_list.size();
+
+  // Extract properties from PyTorch tensors
+  std::vector<void*> input_dptr_list, output_dptr_list;
+  std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
+  std::vector<transformer_engine::DType> input_type_list;
+  void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
+  void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
+  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
+    input_dptr_list.push_back(d_input_ptr);
+    output_dptr_list.push_back(d_output_ptr);
+
+    // Move the input pointer to the next split.
+    char* input_char_ptr = reinterpret_cast<char*>(d_input_ptr);
+    const size_t input_dptr_offset =
+        input_row_list[tensor_id] * input.size(1) * input.element_size();
+    input_char_ptr += input_dptr_offset;
+    d_input_ptr = reinterpret_cast<void*>(input_char_ptr);
+
+    input_shape_list.push_back({input_row_list[tensor_id], static_cast<size_t>(input.size(1))});
+    input_type_list.push_back(GetTransformerEngineDType(input.scalar_type()));
+
+    // Move the output pointer to the next split.
+    char* output_char_ptr = reinterpret_cast<char*>(d_output_ptr);
+    const size_t output_dptr_offset =
+        unpadded_input_row_list[tensor_id] * output.size(1) * output.element_size();
+    output_char_ptr += output_dptr_offset;
+    d_output_ptr = reinterpret_cast<void*>(output_char_ptr);
+
+    output_shape_list.push_back(
+        {unpadded_input_row_list[tensor_id], static_cast<size_t>(output.size(1))});
+  }
+
+  // Construct TE tensors
+  std::vector<NVTETensor> nvte_input_list, nvte_output_list;
+  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
+                                        transformer_engine::DType dtype) -> NVTETensor {
+    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
+    return tensor_wrappers.back().data();
+  };
+  std::vector<int> unpadded_num_rows_list;
+  for (size_t i = 0; i < input_dptr_list.size(); ++i) {
+    if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue;
+    nvte_input_list.emplace_back(
+        make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i]));
+    nvte_output_list.emplace_back(
+        make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i]));
+    unpadded_num_rows_list.emplace_back(unpadded_input_row_list[i]);
+  }
+
+  // Check tensor lists
+  NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(),
+             "Number of input and output tensors must match");
+  NVTE_CHECK(unpadded_num_rows_list.size() == nvte_input_list.size() &&
+             "Number of input and padded row list must match");
+
+  // Launch TE kernel
+  nvte_multi_unpadding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
+                       unpadded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
+}
+
 }  // namespace transformer_engine::pytorch
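fused_multi_row_unpadding strips per-split padding rows from a 2-D tensor in one kernel launch: input_row_list gives each split's (padded) row count in the input, and unpadded_input_row_list gives the row count each split keeps in the output. A hypothetical call from Python, assuming the binding is exposed on the compiled extension module (import name assumed, not part of this commit):

    import torch
    import transformer_engine_torch as tex   # assumed import name

    padded = torch.randn(96, 128, device="cuda")    # three splits padded to 32 rows each
    unpadded = torch.empty(70, 128, device="cuda")  # 30 + 28 + 12 rows after unpadding
    tex.fused_multi_row_unpadding(padded, unpadded, [32, 32, 32], [30, 28, 12])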
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @
44740c6c
...
@@ -12,7 +12,9 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

-#include <stdexcept>
+#include <memory>
+#include <optional>
+#include <vector>

 #include "../common.h"
 #include "../extensions.h"
...
@@ -206,10 +208,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
         py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
   m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
-  m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
-        "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
-        py::arg("quantizer_list"), py::arg("otype"));
+  m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
+        "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
+  m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
+        "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
+        py::arg("quantizer_list"));
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
         "Grouped GEMM");
 #ifdef USE_ROCM
...
@@ -242,6 +245,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
         "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
+  m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
+        "Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
...
@@ -266,6 +271,32 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
         "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
+  // fused router
+  m.def("fused_topk_with_score_function_fwd",
+        &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"),
+        py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"),
+        py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"),
+        py::arg("expert_bias"), "Fused topk softmax fwd");
+  m.def("fused_topk_with_score_function_bwd",
+        &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"),
+        py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"),
+        py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"),
+        py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd");
+  m.def("fused_score_for_moe_aux_loss_fwd",
+        &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"),
+        py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd");
+  m.def("fused_score_for_moe_aux_loss_bwd",
+        &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"),
+        py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"),
+        py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
+  m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
+        py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
+        py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"),
+        py::arg("coeff"), "Fused aux loss fwd");
+  m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
+        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
+        py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
   // Misc
   m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
         "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
...
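From Python, these bindings are reachable through the compiled extension, which other files in this tree import as `tex`. A hedged sketch of calling the new fused router forward with the keyword names declared above (the import path, the keyword-argument style, and the tensor sizes are assumptions made for illustration only):

import torch
from transformer_engine.pytorch import cpp_extensions as tex  # assumed re-export path

logits = torch.randn(8, 4, device="cuda")  # 8 tokens, 4 experts (illustrative sizes)
probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd(
    logits=logits,
    topk=2,
    use_pre_softmax=False,
    num_groups=None,          # grouped top-k disabled
    group_topk=None,
    scaling_factor=None,
    score_function="softmax",
    expert_bias=None,
)
print(probs.shape, routing_map.dtype)  # torch.Size([8, 4]) torch.bool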
transformer_engine/pytorch/csrc/extensions/router.cpp
0 → 100644
View file @
44740c6c
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../extensions.h"
#include "common.h"

namespace transformer_engine::pytorch {

static std::map<std::string, int> score_function_map = {{"sigmoid", 0}, {"softmax", 1}};

std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
    at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
    c10::optional<int> group_topk, c10::optional<float> scaling_factor,
    std::string score_function, c10::optional<at::Tensor> expert_bias) {
  int num_tokens = logits.size(0);
  int num_experts = logits.size(1);

  // Check if the input is valid
  TORCH_CHECK(num_tokens > 0 && num_experts > 0,
              "num_tokens and num_experts must be greater than 0");
  // Expert bias only happens at the sigmoid case
  if (expert_bias.has_value()) {
    TORCH_CHECK(score_function == "sigmoid",
                "score_function must be sigmoid when expert_bias is not None");
  }
  // Check if the score function is valid
  TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid",
              "score_function must be softmax or sigmoid for router fusion");
  if (score_function == "sigmoid") {
    use_pre_softmax = false;  // Pre-softmax only happens at the softmax case
  }

  // Reformat the input to make it compatible with the kernel
  int group_topk_value = group_topk.has_value() ? group_topk.value() : -1;
  int num_groups_value = num_groups.has_value() ? num_groups.value() : -1;
  float scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;

  // Construct the output tensor
  at::Tensor probs =
      at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
  at::Tensor routing_map =
      at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA));
  // Intermediate output is used to store the output of the softmax/sigmoid function
  at::Tensor intermediate_output =
      at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));

  auto logits_cu = makeTransformerEngineTensor(logits);
  auto probs_cu = makeTransformerEngineTensor(probs);
  auto routing_map_cu = makeTransformerEngineTensor(routing_map);
  auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
  auto expert_bias_cu = TensorWrapper();  // empty expert_bias_cu tensor
  if (expert_bias.has_value()) {
    expert_bias_cu = makeTransformerEngineTensor(expert_bias.value());
  }

  nvte_fused_topk_with_score_function_forward(
      logits_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, num_groups_value,
      group_topk_value, scaling_factor_value, score_function_map[score_function],
      expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(),
      intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream());

  return std::make_tuple(probs, routing_map, intermediate_output);
}

at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
                                              at::Tensor routing_map,
                                              at::Tensor intermediate_output,
                                              at::Tensor grad_probs, int topk,
                                              bool use_pre_softmax,
                                              c10::optional<float> scaling_factor,
                                              std::string score_function) {
  // Get the value of the parameters
  auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
  auto score_function_value = score_function_map[score_function];

  // Init the output tensor
  at::Tensor grad_logits = at::empty(
      {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA));

  auto routing_map_cu = makeTransformerEngineTensor(routing_map);
  auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
  auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
  auto grad_logits_cu = makeTransformerEngineTensor(grad_logits);

  nvte_fused_topk_with_score_function_backward(
      routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens,
      num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value,
      grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());

  return grad_logits;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
    at::Tensor logits, int topk, std::string score_function) {
  int num_tokens = logits.size(0);
  int num_experts = logits.size(1);

  // Check if the input is valid
  TORCH_CHECK(num_tokens > 0 && num_experts > 0,
              "num_tokens and num_experts must be greater than 0");
  TORCH_CHECK(topk > 0, "topk must be greater than 0");
  // Check if the score function is valid
  TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid",
              "score_function must be softmax or sigmoid for router fusion");
  int score_function_value = score_function_map[score_function];

  // Construct the output tensor
  at::Tensor scores =
      at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
  at::Tensor routing_map =
      at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA));
  at::Tensor intermediate_output =
      at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));

  auto logits_cu = makeTransformerEngineTensor(logits);
  auto scores_cu = makeTransformerEngineTensor(scores);
  auto routing_map_cu = makeTransformerEngineTensor(routing_map);
  auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);

  nvte_fused_score_for_moe_aux_loss_forward(
      logits_cu.data(), num_tokens, num_experts, topk, score_function_value, scores_cu.data(),
      routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream());

  return std::make_tuple(scores, routing_map, intermediate_output);
}

at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
                                            at::Tensor intermediate_output,
                                            at::Tensor grad_scores, int topk,
                                            std::string score_function) {
  // Get the value of the parameters
  int score_function_value = score_function_map[score_function];

  // Init the output tensor
  at::Tensor grad_logits = at::empty(
      {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA));

  auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
  auto grad_scores_cu = makeTransformerEngineTensor(grad_scores);
  auto grad_logits_cu = makeTransformerEngineTensor(grad_logits);

  nvte_fused_score_for_moe_aux_loss_backward(
      intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk,
      score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());

  return grad_logits;
}

std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
                                                          at::Tensor tokens_per_expert,
                                                          int total_num_tokens, int num_experts,
                                                          int num_rows, int num_cols, int topk,
                                                          float coeff) {
  TORCH_CHECK(topk > 0, "topk must be greater than 0");
  TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0");
  TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0");

  // Create the output tensor
  at::Tensor aux_loss = at::empty({}, at::dtype(probs.scalar_type()).device(at::kCUDA));
  at::Tensor Const_buf = at::empty({}, at::dtype(at::kFloat).device(at::kCUDA));

  auto probs_cu = makeTransformerEngineTensor(probs);
  auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
  auto aux_loss_cu = makeTransformerEngineTensor(aux_loss);
  auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);

  nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
                                  num_experts, num_rows, num_cols, topk, coeff,
                                  aux_loss_cu.data(), Const_buf_cu.data(),
                                  at::cuda::getCurrentCUDAStream());

  return std::make_tuple(aux_loss, Const_buf);
}

at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
                                  int num_rows, int num_cols, at::Tensor grad_aux_loss) {
  // Create the output tensor
  at::Tensor grad_probs =
      at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));

  auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
  auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
  auto grad_aux_loss_cu = makeTransformerEngineTensor(grad_aux_loss);
  auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);

  // Meta data for the kernel
  nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows,
                                   num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(),
                                   at::cuda::getCurrentCUDAStream());

  return grad_probs;
}

}  // namespace transformer_engine::pytorch
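As a rough guide to the kernel semantics above, an eager-mode PyTorch equivalent of the softmax/top-k path (no grouped top-k, no expert bias; this sketches the math only and is not the fused kernel's exact numerics or API):

import torch

def topk_softmax_reference(logits, topk, use_pre_softmax, scaling_factor=1.0):
    # Reference for score_function="softmax"; returns (probs, routing_map).
    num_tokens, num_experts = logits.shape
    if use_pre_softmax:
        scores = torch.softmax(logits, dim=-1)      # softmax over all experts first
        vals, idx = torch.topk(scores, topk, dim=-1)
    else:
        top_logits, idx = torch.topk(logits, topk, dim=-1)
        vals = torch.softmax(top_logits, dim=-1)    # softmax over the selected experts only
    routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool,
                              device=logits.device).scatter_(1, idx, True)
    probs = torch.zeros_like(logits).scatter_(1, idx, vals * scaling_factor)
    return probs, routing_map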
transformer_engine/pytorch/csrc/extensions/transpose.cpp
View file @
44740c6c
...
@@ -4,80 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <pybind.h>
 #include <optional>
+#include <vector>

 #include "../extensions.h"
 #include "pybind.h"

-namespace transformer_engine::pytorch {
+namespace transformer_engine {
+namespace pytorch {

-std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
-                                             std::optional<std::vector<py::object>> output_list,
-                                             std::vector<py::handle> quantizer_list, DType otype) {
-  init_extension();
-  std::vector<NVTETensor> nvte_tensor_input_list;
-  std::vector<NVTETensor> nvte_tensor_output_list;
-  std::vector<py::object> py_output_objects_list;
-  std::vector<TensorWrapper> tensor_wrappers;
-  if (output_list.has_value()) {
-    py_output_objects_list = output_list.value();
-  }
-
-  // Choose implementation
-  // Note: Currently only have fused kernel for FP8 cast-transpose
-  bool with_fused_kernel = true;
-
-  // create TE tensors from input
-  for (size_t i = 0; i < input_list.size(); i++) {
-    auto input_tensor = makeTransformerEngineTensor(input_list[i]);
-    const NVTEShape input_shape = input_tensor.shape();
-
-    TensorWrapper output_tensor;
-
-    if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
-      with_fused_kernel = false;
-    }
-    if (output_list == std::nullopt) {
-      std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
-      std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
-      py::object o;
-      std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype);
-      py_output_objects_list.push_back(o);
-    } else {
-      output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]);
-    }
-    if (input_tensor.numel() == 0) continue;
-
-    nvte_tensor_output_list.emplace_back(output_tensor.data());
-    nvte_tensor_input_list.emplace_back(input_tensor.data());
-    tensor_wrappers.emplace_back(std::move(input_tensor));
-    tensor_wrappers.emplace_back(std::move(output_tensor));
-  }
-
-  // Check tensor lists
-  NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
-             "Number of input and output tensors must match");
-
-  for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
-    if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
-      with_fused_kernel = false;
-      break;
-    }
-  }
-
-  // Launch TE kernel
-  if (with_fused_kernel) {
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
-                                nvte_tensor_output_list.data(),
-                                at::cuda::getCurrentCUDAStream());
-    });
-  } else {
-    for (size_t i = 0; i < py_output_objects_list.size(); i++) {
-      quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
-    }
-  }
-  return py_output_objects_list;
-}
-
 at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
   init_extension();
...
@@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
   return out;
 }

-}  // namespace transformer_engine::pytorch
+}  // namespace pytorch
+}  // namespace transformer_engine
transformer_engine/pytorch/csrc/quantizer.cpp
View file @
44740c6c
...
@@ -283,10 +283,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
   using namespace pybind11::literals;
   std::vector<int64_t> torch_shape;
-  size_t numel = 1;
   for (auto s : shape) {
     torch_shape.emplace_back(static_cast<int64_t>(s));
-    numel *= s;
   }

   TensorWrapper tensor(this->get_scaling_mode());
...
@@ -296,10 +294,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
   scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);

-  size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
-  size_t m_dim = numel / k_dim;
-  size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
   Float8BlockScaleTensorFormat data_format =
       (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
                         : Float8BlockScaleTensorFormat::GEMM_READY);
...
@@ -310,30 +304,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     } else {
       data_rowwise = at::empty(torch_shape, opts);
     }
-    size_t sinv0 = 0;
-    size_t sinv1 = 0;
-    if (block_scaling_dim == 2) {
-      // 2D scaling is always GEMM_READY for now
-      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
-                 "2D scaling is always GEMM_READY for now.");
-      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
-    } else if (block_scaling_dim == 1) {
-      // 1D scaling can be GEMM_READY or COMPACT
-      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
-      // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
-      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
-      // if the rowwise format is compact, the scaling factor is not be transposed
-      if (rowwise_compact) {
-        std::swap(sinv0, sinv1);
-      }
-    } else {
-      NVTE_ERROR("Unsupported block_scaling_dim in create_tensor rowwise. "
-                 "Expected 1 or 2. Got ", block_scaling_dim);
-    }
+    auto scale_shape = get_scale_shape(shape, false);
+    size_t sinv0 = scale_shape[0];
+    size_t sinv1 = scale_shape[1];
     scale_inv_rowwise =
         at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
     tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
...
@@ -364,27 +337,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         columnwise_shape = shape;
       }
     }
-    size_t sinv0 = 0;
-    size_t sinv1 = 0;
-    if (block_scaling_dim == 2) {
-      // 2D scaling is always GEMM_READY for now
-      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
-                 "2D scaling is always GEMM_READY for now.");
-      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
-    } else if (block_scaling_dim == 1) {
-      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
-      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
-      // GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
-      // for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data,
-      // scaling factor is also [sinv0, sinv1], so no need to swap sinv0 and sinv1 here
-    } else {
-      NVTE_ERROR("Unsupported block_scaling_dim in create_tensor columnwise. "
-                 "Expected 1 or 2. Got ", block_scaling_dim);
-    }
+    auto scale_shape = get_scale_shape(shape, true);
+    size_t sinv0 = scale_shape[0];
+    size_t sinv1 = scale_shape[1];
     data_colwise = at::empty(torch_columnwise_shape, opts);
     scale_inv_colwise =
         at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
...
@@ -418,6 +373,81 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   return {std::move(tensor), std::move(ret)};
 }

+std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
+                                                          bool columnwise) const {
+  size_t numel = 1;
+  for (auto s : shape) {
+    numel *= s;
+  }
+
+  size_t k_dim = shape.size() == 0 ? 1u : shape.back();
+  size_t m_dim = numel / k_dim;
+  constexpr size_t kBlockLen = 128;
+
+  Float8BlockScaleTensorFormat data_format =
+      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
+                        : Float8BlockScaleTensorFormat::GEMM_READY);
+
+  std::vector<size_t> scale_shape;
+
+  bool rowwise_usage = !columnwise;
+
+  if (rowwise_usage) {
+    // rowwise scaling factor shape
+    size_t sinv0 = 0;
+    size_t sinv1 = 0;
+    if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
+      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
+    } else if (block_scaling_dim == 1) {
+      // 1D scaling can be GEMM_READY or COMPACT
+      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
+      // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
+      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
+      // if the rowwise format is compact, the scaling factor is not be transposed
+      if (rowwise_compact) {
+        std::swap(sinv0, sinv1);
+      }
+    } else {
+      NVTE_CHECK(false,
+                 "Unsupported block_scaling_dim in create_tensor rowwise."
+                 "Expected 1 or 2. Got ",
+                 block_scaling_dim);
+    }
+    scale_shape = {sinv0, sinv1};
+  } else {
+    // columnwise scaling factor shape
+    size_t sinv0 = 0;
+    size_t sinv1 = 0;
+    if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
+      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
+    } else if (block_scaling_dim == 1) {
+      // 1D scaling can be GEMM_READY or COMPACT
+      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
+      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
+      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
+      // GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
+      // for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data,
+      // scaling factor is also [sinv0, sinv1], so no need to swap sinv0 and sinv1 here
+    } else {
+      NVTE_CHECK(false,
+                 "Unsupported block_scaling_dim in create_tensor columnwise."
+                 "Expected 1 or 2. Got ",
+                 block_scaling_dim);
+    }
+    scale_shape = {sinv0, sinv1};
+  }
+  return scale_shape;
+}
+
 MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
   this->dtype = quantizer.attr("dtype").cast<DType>();
 }
...
@@ -450,11 +480,6 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
   at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
       columnwise_scale_inv;  // TODO(pgadzinski) - change
   opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
-  auto last_dim = static_cast<size_t>(torch_shape.back());
-
-  NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
-             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
-             " (got shape=", torch_shape, ")");

   at::Tensor data;
   if (rowwise_usage) {
...
@@ -463,9 +488,10 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
     } else {
       data = at::empty(torch_shape, opts);
     }
-    auto sinv0 = roundup(numel / last_dim, 128);
-    auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
-    rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
+    auto scale_shape = get_scale_shape(shape, false);
+    size_t sinv0 = scale_shape[0];
+    size_t sinv1 = scale_shape[1];
+    rowwise_scale_inv =
+        at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
     tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
     tensor.set_rowwise_scale_inv(
         rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
...
@@ -473,10 +499,12 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
   }

   if (columnwise_usage) {
-    auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
-    auto sinv1 = roundup(last_dim, 128);
+    auto scale_shape = get_scale_shape(shape, true);
+    size_t sinv0 = scale_shape[0];
+    size_t sinv1 = scale_shape[1];
     columnwise_data = at::empty(torch_shape, opts);
-    columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
+    columnwise_scale_inv =
+        at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
     tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
     tensor.set_columnwise_scale_inv(
...
@@ -504,4 +532,35 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
   return {std::move(tensor), std::move(ret)};
 }

+std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
+                                                    bool columnwise) const {
+  size_t numel = 1;
+  for (auto s : shape) {
+    numel *= s;
+  }
+
+  auto last_dim = shape.back();
+
+  NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
+             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
+             " (got shape=", shape, ")");
+
+  std::vector<size_t> scale_shape;
+
+  bool rowwise_usage = !columnwise;
+
+  if (rowwise_usage) {
+    // rowwise scaling factor shape
+    size_t sinv0 = roundup(numel / last_dim, 128);
+    size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
+    scale_shape = {sinv0, sinv1};
+  } else {
+    // columnwise scaling factor shape
+    size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
+    size_t sinv1 = roundup(last_dim, 128);
+    scale_shape = {sinv0, sinv1};
+  }
+  return scale_shape;
+}
+
 }  // namespace transformer_engine::pytorch
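To make the new get_scale_shape helpers concrete, here is a small arithmetic check in plain Python. It mirrors the rounding rules above for a 256x512 tensor; the values are illustrative and this is not the extension API:

def roundup(x, m):
    return ((x + m - 1) // m) * m

m_dim, k_dim = 256, 512
MXFP8_BLOCK_SIZE = 32

# MXFP8 scale-inverse shapes
mxfp8_rowwise = (roundup(m_dim, 128), roundup(k_dim // MXFP8_BLOCK_SIZE, 4))      # (256, 16)
mxfp8_columnwise = (roundup(m_dim // MXFP8_BLOCK_SIZE, 4), roundup(k_dim, 128))   # (8, 512)

# Float8 block scaling with 128-element blocks, 1D GEMM_READY rowwise case
kBlockLen = 128
fp8_block_rowwise_1d = ((k_dim + kBlockLen - 1) // kBlockLen, roundup(m_dim, 4))  # (4, 256)

print(mxfp8_rowwise, mxfp8_columnwise, fp8_block_rowwise_1d)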
transformer_engine/pytorch/export.py
0 → 100644
View file @
44740c6c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Export utilities for TransformerEngine"""
from contextlib import contextmanager
from typing import Generator

import torch

_IN_ONNX_EXPORT_MODE = False

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])


@contextmanager
def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
    """
    Context manager for exporting to ONNX.

    .. code-block:: python

        from transformer_engine.pytorch.export import onnx_export, te_translation_table

        with onnx_export(enabled=True):
            torch.onnx.export(model, dynamo=True, custom_translation_table=te_translation_table)

    Parameters
    ----------
    enabled: bool, default = `False`
             whether or not to enable export
    """
    global _IN_ONNX_EXPORT_MODE
    onnx_export_state = _IN_ONNX_EXPORT_MODE
    if (TORCH_MAJOR, TORCH_MINOR) < (2, 4):
        raise RuntimeError("ONNX export is not supported for PyTorch versions less than 2.4")
    try:
        _IN_ONNX_EXPORT_MODE = enabled
        yield
    finally:
        _IN_ONNX_EXPORT_MODE = onnx_export_state


def is_in_onnx_export_mode() -> bool:
    """Returns True if onnx export mode is enabled, False otherwise."""
    return _IN_ONNX_EXPORT_MODE


def assert_warmed_up(module: torch.nn.Module) -> None:
    """Assert that the model has been warmed up before exporting to ONNX."""
    assert hasattr(module, "forwarded_at_least_once"), (
        "Model must be warmed up before exporting to ONNX, please run model with the"
        " same recipe before exporting."
    )


if TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2:
    # pylint: disable=unused-import
    from .onnx_extensions import (
        torch_onnx_gemm_inf_op,
        onnx_quantize_fp8_op,
        onnx_dequantize_fp8_op,
        onnx_quantize_mxfp8_op,
        onnx_dequantize_mxfp8_op,
        onnx_layernorm,
        onnx_attention_mask_func,
        onnx_gemm,
        te_translation_table,
    )
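A hedged end-to-end sketch of how these export utilities are meant to be used; the module choice, tensor sizes, and the exact torch.onnx.export call are illustrative, while the overall pattern follows the docstring above:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.export import assert_warmed_up, onnx_export, te_translation_table

model = te.Linear(1024, 1024).cuda()            # hypothetical TE module
sample = torch.randn(8, 1024, device="cuda")

model(sample)              # warm-up pass; TE modules record `forwarded_at_least_once`
assert_warmed_up(model)    # raises if the model was never run with the target recipe

with onnx_export(enabled=True):
    torch.onnx.export(model, (sample,), dynamo=True,
                      custom_translation_table=te_translation_table)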
transformer_engine/pytorch/fp8.py
View file @
44740c6c
...
@@ -56,6 +56,8 @@ def check_fp8_support() -> Tuple[bool, str]:
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
+    if get_device_compute_capability() >= (12, 0):
+        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
...
@@ -79,7 +81,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
 def get_default_fp8_recipe() -> Recipe:
     """FP8 recipe with default args."""
-    if get_device_compute_capability() >= (10, 0):  # blackwell and above
+    if check_mxfp8_support()[0]:
+        # This is a temporary restriction until MXFP8 is supported for all
+        # gemm layouts.
+        if get_device_compute_capability() >= (12, 0):
+            return Float8BlockScaling()
         return MXFP8BlockScaling()
     return DelayedScaling()
...
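A quick way to see which default recipe these helpers pick on the current GPU (small illustrative snippet; it requires a visible CUDA device):

from transformer_engine.pytorch.fp8 import check_mxfp8_support, get_default_fp8_recipe

supported, reason = check_mxfp8_support()
recipe = get_default_fp8_recipe()
print(f"MXFP8 supported: {supported} ({reason or 'n/a'})")
print(f"Default recipe: {type(recipe).__name__}")
# The returned recipe can be passed directly to te.fp8_autocast(fp8_recipe=recipe).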
transformer_engine/pytorch/graph.py
View file @
44740c6c
...
@@ -21,6 +21,7 @@ from .fp8 import (
 from .distributed import get_all_rng_states, graph_safe_rng_available
 from .module.base import TransformerEngineBaseModule
 from .ops.op import BasicOperation
+from .utils import make_weak_ref

 __all__ = ["make_graphed_callables"]
...
@@ -63,8 +64,10 @@ def _make_graphed_callables(
     fp8_weight_caching: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
     _order: Optional[List[int]] = None,
+    _num_layers_per_chunk: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
     retain_graph_in_backward: bool = False,
+    _reuse_graph_input_output_buffers: bool = False,
 ) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
...
@@ -110,29 +113,113 @@ def _make_graphed_callables(
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
         # Note: The model is assumed to consist of layers
         # (corresponding to callables) that are grouped into
-        # equally-sized model chunks. _order is a list of chunk
-        # indices (1-indexed) that indicates the order in which the
-        # layers are evaluated. Positive values indicate forward
-        # passes and negative values indicate backward passes. Each
+        # model chunks. _num_layers_per_chunk is a list of integers
+        # that indicates the number of layers in each model chunk.
+        # _order is a list of chunk indices (1-indexed) that
+        # indicates the order in which the layers are evaluated.
+        # Positive values indicate forward passes and negative
+        # values indicate backward passes. Each
         # entry in sample_args corresponds to one of the forward
         # passes.
         num_model_chunks = max(_order)
         num_microbatches = len(_order) // num_model_chunks // 2
         assert num_model_chunks * num_microbatches * 2 == len(_order)
-        assert len(sample_args) * 2 >= len(_order) and (
-            len(sample_args) * 2 % len(_order) == 0
-        ), f"{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0"
-        num_layers = len(sample_args) // num_model_chunks // num_microbatches
-        assert len(callables) == num_model_chunks * num_layers, (
-            f"Callables should have ({num_model_chunks * num_layers}) "
+
+        # Determine number of layers in each model chunk.
+        if _num_layers_per_chunk is None:
+            assert len(sample_args) * 2 >= len(_order) and (
+                len(sample_args) * 2 % len(_order) == 0
+            ), (
+                f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %"
+                f" {len(_order)} == 0"
+            )
+            num_layers = len(sample_args) // num_model_chunks // num_microbatches
+            _num_layers_per_chunk = [num_layers] * num_model_chunks
+        else:
+            assert (
+                isinstance(_num_layers_per_chunk, int)
+                or len(_num_layers_per_chunk) == num_model_chunks
+            ), (
+                "If _num_layers_per_chunk is provided, it must be an integer or a list of"
+                f" {num_model_chunks} integers, but got {_num_layers_per_chunk}."
+            )
+            if isinstance(_num_layers_per_chunk, int):
+                _num_layers_per_chunk = [_num_layers_per_chunk] * num_model_chunks
+        total_num_layers = sum(_num_layers_per_chunk)
+        assert len(callables) == total_num_layers, (
+            f"Callables should have ({total_num_layers}) "
             + f"entries when order input is provided but got {len(callables)}."
         )
-        assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, (
-            f"Expected {num_model_chunks * num_microbatches} "
+        assert len(sample_args) == total_num_layers * num_microbatches, (
+            f"Expected {total_num_layers * num_microbatches} "
             + f"args tuple, but got {len(sample_args)}."
         )
+        # Calculate the starting index of each chunk in callables for future use.
+        _prefix_num_layers = [0]
+        for m_chunk in range(num_model_chunks):
+            num_layers = _num_layers_per_chunk[m_chunk]
+            _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers)
+
         assert len(sample_kwargs) == len(sample_args)

+        # Check reuse graph conditions and reorganize sample_args and sample_kwargs.
+        # Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers
+        # when the graph is replayed. If two model chunk microbatches have no overlap between their
+        # forward and backward, then we can reduce memory usage by reusing the same static buffers.
+        if _reuse_graph_input_output_buffers:
+            assert (
+                _order is not None
+            ), "`_order` must be provided when `_reuse_graph_input_output_buffers` is True."
+            assert (
+                is_training
+            ), "`_reuse_graph_input_output_buffers` is only available in training mode."
+            assert isinstance(
+                sample_args, list
+            ), "sample_args must be a list for _reuse_graph_input_output_buffers."
+            len_args = len(sample_args[0])
+            for i, arg in enumerate(sample_args):
+                assert len_args == len(
+                    arg
+                ), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`."
+            len_kwargs = len(sample_kwargs[0])
+            assert isinstance(
+                sample_kwargs, list
+            ), "sample_kwargs must be a list for _reuse_graph_input_output_buffers."
+            for i, kwarg in enumerate(sample_kwargs):
+                assert len_kwargs == len(kwarg), (
+                    "Keyword arguments must have same length and shape for"
+                    " `_reuse_graph_input_output_buffers`."
+                )
+
+            # Reorganize args and kwargs for input tensor reuse.
+            fwd_sample_qs = {}
+            consumed_sample_q = []
+            fwd_idx = [0] * num_model_chunks
+            for c_id in _order:
+                m_chunk = abs(c_id) - 1
+                if c_id > 0:
+                    sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk]
+                    )
+                    fwd_sample_idx = [
+                        sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk])
+                    ]
+                    fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx
+                    for per_callable_fwd_idx in fwd_sample_idx:
+                        if consumed_sample_q:
+                            reuse_fwd_idx = consumed_sample_q.pop(0)
+                            sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
+                            sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
+                    fwd_idx[m_chunk] += 1
+                else:
+                    num_consumed_samples = min(
+                        len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
+                    )
+                    consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples]
+                    fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]

     if fp8_weight_caching:
         # Initialize flag that controls FP8 weight updates
         FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
...
@@ -185,10 +272,13 @@ def _make_graphed_callables(
     per_callable_module_params = []
     for m_chunk in range(num_model_chunks):
         for _ in range(num_microbatches):
-            for l_no in range(num_layers):
+            for l_no in range(_num_layers_per_chunk[m_chunk]):
                 per_callable_module_params.append(
-                    tuple(callables[m_chunk * num_layers + l_no].parameters())
-                    if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
+                    tuple(callables[_prefix_num_layers[m_chunk] + l_no].parameters())
+                    if isinstance(
+                        callables[_prefix_num_layers[m_chunk] + l_no],
+                        torch.nn.Module,
+                    )
                     else ()
                 )
     assert len(per_callable_module_params) == len(flatten_sample_args)
...
@@ -227,10 +317,10 @@ def _make_graphed_callables(
         for c_id in _order:
             if c_id > 0:
                 m_chunk = c_id - 1
-                for l_no in range(num_layers):
-                    func = callables[m_chunk * num_layers + l_no]
-                    func_idx = (m_chunk * num_microbatches * num_layers) + (
-                        fwd_idx[m_chunk] * num_layers + l_no
+                for l_no in range(_num_layers_per_chunk[m_chunk]):
+                    func = callables[_prefix_num_layers[m_chunk] + l_no]
+                    func_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                     )
                     warmup_func_idx.append(func_idx)
                     warmup_func.append(func)
...
@@ -255,7 +345,7 @@ def _make_graphed_callables(
         args = sample_args[func_idx]
         kwargs = sample_kwargs[func_idx]
         static_input_surface = per_callable_static_input_surfaces[func_idx]
-        for _ in range(num_warmup_iters):
+        for warmup_iter in range(num_warmup_iters):
             hooks = []
             for module in func.modules():
                 hook = module.register_forward_hook(hook_fn)
...
@@ -271,6 +361,34 @@ def _make_graphed_callables(
                     only_inputs=True,
                     allow_unused=allow_unused_input,
                 )
+                # Filter module params that get None grad from grad_inputs and remove them
+                # from static_input_surface. This is to ensure that the backward hooks
+                # registered to these params are not wrongly triggered.
+                num_required_grad_sample_args = sum(
+                    arg.requires_grad for arg in flatten_sample_args[func_idx]
+                )
+                required_grad_input_idx = []
+                for i, arg in enumerate(static_input_surface):
+                    if arg.requires_grad:
+                        required_grad_input_idx.append(i)
+                module_params_with_grad = []
+                for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
+                    if (
+                        grad_inputs[grad_inputs_idx] is not None
+                        and grad_inputs_idx >= num_required_grad_sample_args
+                    ):
+                        module_params_with_grad.append(static_input_surface[inputs_idx])
+                if len(module_params_with_grad) != len(per_callable_module_params[func_idx]):
+                    assert warmup_iter == 0, (
+                        "no-grad params should only be used as inputs in the first warmup"
+                        " iteration"
+                    )
+                    per_callable_module_params[func_idx] = tuple(module_params_with_grad)
+                    static_input_surface = flatten_sample_args[func_idx] + tuple(
+                        module_params_with_grad
+                    )
+                    per_callable_static_input_surfaces[func_idx] = static_input_surface
             else:
                 grad_inputs = None
             del outputs, grad_inputs
...
@@ -292,14 +410,16 @@ def _make_graphed_callables(
         per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
         fwd_idx = [0] * num_model_chunks
         bwd_idx = [0] * num_model_chunks
+        static_grad_outputs = None
+        previous_per_callable_bwd_idx = None
         for c_id in _order:
             if c_id > 0:
                 # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
                 m_chunk = c_id - 1
-                for l_no in range(num_layers):
-                    func = callables[m_chunk * num_layers + l_no]
-                    per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
-                        fwd_idx[m_chunk] * num_layers + l_no
+                for l_no in range(_num_layers_per_chunk[m_chunk]):
+                    func = callables[_prefix_num_layers[m_chunk] + l_no]
+                    per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                     )
                     args = sample_args[per_callable_fwd_idx]
                     kwargs = sample_kwargs[per_callable_fwd_idx]
...
@@ -314,14 +434,17 @@ def _make_graphed_callables(
             else:
                 # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
                 m_chunk = -c_id - 1
-                for l_no in list(reversed(range(num_layers))):
-                    per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) + (
-                        bwd_idx[m_chunk] * num_layers + l_no
+                for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
+                    per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                     )
                     static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
                     static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
                     bwd_graph = bwd_graphs[per_callable_bwd_idx]
                     # For now, assumes all static_outputs require grad
-                    static_grad_outputs = tuple(
-                        torch.empty_like(o) if o.requires_grad else None for o in static_outputs
-                    )
+                    if not _reuse_graph_input_output_buffers or static_grad_outputs is None:
+                        # Note for _reuse_graph_input_output_buffers: grad output is only used
+                        # within backward, so we can reuse the same static buffers every time.
+                        static_grad_outputs = tuple(
+                            torch.empty_like(o) if o.requires_grad else None
+                            for o in static_outputs
+                        )
...
@@ -350,6 +473,30 @@ def _make_graphed_callables(
                     per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
                     per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
+
+                    # Weak ref the static outputs and static grad inputs that are no longer needed
+                    # in the following steps. These two type of tensors are both in cudagraph
+                    # mempool, so we just deallocate them and let PyTorch's memory allocator
+                    # reuse them elsewhere.
+                    if _reuse_graph_input_output_buffers:
+                        # Weak ref the static outputs of the forward pass of this backward. It's
+                        # no longer needed after the corresponding backward graph is built up.
+                        per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
+                            static_outputs
+                        )
+
+                        # Weak ref the static grad inputs of the previous backward pass.
+                        # Note: After a backward pass, we assume Mcore will send the
+                        # grad input to another pipeline parallel rank and that the
+                        # communication is finished before the end of the next backward
+                        # pass.
+                        if previous_per_callable_bwd_idx is not None:
+                            per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = (
+                                make_weak_ref(
+                                    per_callable_static_grad_inputs[previous_per_callable_bwd_idx]
+                                )
+                            )
+                        previous_per_callable_bwd_idx = per_callable_bwd_idx
+
                     bwd_idx[m_chunk] += 1
         else:
             # Capture forward graphs
...
@@ -593,7 +740,7 @@ def save_fp8_tensors(
             m.adjust_amax_history_length(fp8_recipe.amax_history_len)
             module_tensors = m.get_fp8_meta_tensors()
         elif isinstance(m, BasicOperation):
-            m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
+            m.pre_first_forward(recipe=fp8_recipe)
             module_tensors = m._save_fp8_metas()
         fp8_tensors.append(module_tensors)
     return fp8_tensors
...
@@ -634,8 +781,10 @@ def make_graphed_callables(
     fp8_group: Optional[dist_group_type] = None,
     fp8_weight_caching: bool = False,
     _order: Optional[List[int]] = None,
+    _num_layers_per_chunk: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
     retain_graph_in_backward: bool = False,
+    _reuse_graph_input_output_buffers: bool = False,
 ) -> Union[Callable, Tuple[Callable, ...]]:
     """
     Make CUDA graph version of Transformer Engine modules
...
@@ -664,6 +813,11 @@ def make_graphed_callables(
                 this graph may share memory with the indicated pool.
     retain_graph_in_backward: bool, default = `False`
                 Whether to set retain_graph=True in backward graph capture.
+    _reuse_graph_input_output_buffers: bool, default = `False`
+                Reduce memory usage by reusing input/output data buffers between
+                graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
+                when `_order` is provided. All callables in `modules` are assumed to have
+                inputs and outputs with the same dtype and shape.

     FP8-related parameters
     ----------------------
...
@@ -702,10 +856,17 @@ def make_graphed_callables(
     saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)

     # FP8 wrapper.
+    old_call_funcs = {}
+
     def wrap_autocast(block):
-        old_forward = block.forward
+        block_cls = type(block)
+        if block_cls in old_call_funcs:
+            return
+        old_call_funcs[block_cls] = block_cls.__call__

-        def forward_func(*args, **kwargs):
+        # Wrap the original call function of the module class.
+        def call_func(*args, **kwargs):
             with fp8_autocast(
                 enabled=fp8_enabled,
                 calibrating=fp8_calibrating,
...
@@ -713,10 +874,10 @@ def make_graphed_callables(
                 fp8_group=fp8_group,
                 _graph=True,
             ):
-                outputs = old_forward(*args, **kwargs)
+                outputs = old_call_funcs[block_cls](*args, **kwargs)
             return outputs

-        block.forward = forward_func
+        block_cls.__call__ = call_func

     forward_funcs = []
     for module in modules:
...
@@ -747,8 +908,10 @@ def make_graphed_callables(
             fp8_weight_caching=fp8_weight_caching,
             sample_kwargs=sample_kwargs,
             _order=_order,
+            _num_layers_per_chunk=_num_layers_per_chunk,
             pool=pool,
             retain_graph_in_backward=retain_graph_in_backward,
+            _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers,
         )

     # Ensures warmup does not affect numerics for ops such as dropout.
...
@@ -758,6 +921,10 @@ def make_graphed_callables(
     else:
         torch.cuda.set_rng_state(original_rng_states)

+    # Remove FP8 wrapper.
+    for module_cls, old_call in old_call_funcs.items():
+        module_cls.__call__ = old_call
+
     # Restore FP8 state.
     restore_fp8_tensors(modules, saved_fp8_tensors)
...
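To see how the new chunk bookkeeping maps forward passes to per-callable indices, here is a standalone sketch with illustrative values (the formulas are the ones used in the diff above):

# 2 model chunks, 2 microbatches; chunk 0 has 3 layers, chunk 1 has 1 layer.
_order = [1, 2, 1, 2, -2, -1, -2, -1]          # + forward, - backward, 1-indexed chunk ids
_num_layers_per_chunk = [3, 1]
num_model_chunks = max(_order)
num_microbatches = len(_order) // num_model_chunks // 2   # -> 2

_prefix_num_layers = [0]
for n in _num_layers_per_chunk:
    _prefix_num_layers.append(_prefix_num_layers[-1] + n)  # -> [0, 3, 4]

fwd_idx = [0] * num_model_chunks
for c_id in (c for c in _order if c > 0):
    m_chunk = c_id - 1
    for l_no in range(_num_layers_per_chunk[m_chunk]):
        per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
            fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
        )
        print(f"chunk {m_chunk}, layer {l_no} -> sample_args[{per_callable_fwd_idx}]")
    fwd_idx[m_chunk] += 1
# Prints indices 0..7 exactly once: total_num_layers (4) * num_microbatches (2) entries.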
transformer_engine/pytorch/jit.py
View file @
44740c6c
...
@@ -6,10 +6,10 @@
 import os
 from functools import wraps
 from typing import Callable, Optional, Tuple

 import torch

 from . import torch_version
+from .export import is_in_onnx_export_mode
 from .utils import gpu_autocast_ctx
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -47,7 +47,17 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"
 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
+no_torch_dynamo = lambda recursive=True: lambda func: func
+if torch.__version__ >= "2":
+    import torch._dynamo
+
+    if torch.__version__ >= "2.1":
+        no_torch_dynamo = lambda recursive=True: lambda f: (
+            f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
+        )
+    else:
+        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
+        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable


 def set_jit_fusion_options() -> None:
...
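The decorator factory above is used on Python entry points that should stay out of TorchDynamo tracing during normal execution but remain traceable while exporting to ONNX. A minimal usage sketch (the decorated function is hypothetical):

from transformer_engine.pytorch.jit import no_torch_dynamo

@no_torch_dynamo()        # note: the factory is called, then applied
def padding_helper(x):
    # Runs outside Dynamo graphs normally; left untouched in ONNX export mode.
    return x * 2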
transformer_engine/pytorch/module/_common.py
View file @
44740c6c
...
@@ -13,6 +13,7 @@ import torch
 from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
+from ..export import is_in_onnx_export_mode
 import warnings

 try:
     from lightop import rmsnorm_forward, rmsnorm_backward
...
@@ -173,6 +174,8 @@ def noop_cat(
         raise ValueError("Attempted to concatenate 0 tensors")
     if len(tensors) == 1:
         return tensors[0]
+    if is_in_onnx_export_mode():
+        return torch.cat(tensors, dim=dim)
     return _NoopCatFunc.apply(dim, *tensors)
...
transformer_engine/pytorch/module/base.py
View file @ 44740c6c
...
@@ -1035,6 +1035,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         to setup the forward aggregated amax reduction for every module
         just in case. The autocast exit will pick up the most recent one.
         """
+        self.forwarded_at_least_once = True
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
...
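The forwarded_at_least_once flag recorded above is what the new ONNX path checks later (assert_warmed_up, imported in layernorm_linear.py below): a module must have run at least one regular forward so its quantizer and workspace state exist before export. A hedged sketch of the resulting calling pattern, assuming the onnx_export context manager exposed by transformer_engine.pytorch.export:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.export import onnx_export  # assumed import path

model = te.LayerNormLinear(1024, 1024).cuda()
warmup = torch.randn(8, 1024, device="cuda")

# One regular forward "warms up" the module; this is what sets forwarded_at_least_once.
with torch.no_grad():
    model(warmup)

# Only afterwards switch to ONNX-export mode, where forward() dispatches to onnx_forward().
with onnx_export(True), torch.no_grad():
    torch.onnx.export(model, (warmup,), "layernorm_linear.onnx")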
transformer_engine/pytorch/module/fp8_padding.py
View file @ 44740c6c
...
@@ -53,15 +53,16 @@ class _Fp8Padding(torch.autograd.Function):
         if ctx.requires_dgrad:
             grad_output = grad_output.contiguous()
-            grad_output_mats = torch.split(
-                grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits
-            )
-            grad_input = torch.cat(
-                [
-                    grad_output_mat[: ctx.m_splits[i]]
-                    for i, grad_output_mat in enumerate(grad_output_mats)
-                ],
-                dim=0,
-            )
+            in_features = grad_output.shape[-1]
+            # Allocate cast and transpose output tensor
+            total_row = sum(ctx.m_splits)
+            grad_input = torch.empty(
+                [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device
+            )
+            tex.fused_multi_row_unpadding(
+                grad_output.view(-1, in_features),
+                grad_input,
+                ctx.padded_m_splits,
+                ctx.m_splits,
+            )

         return (grad_input, None, None, None)
...
@@ -73,11 +74,12 @@ class Fp8Padding(torch.nn.Module):
     Parameters
     ----------
-    num_gemms: int
-              number of GEMMs to be performed simutaneously.
-    align_size: int, optional
-              the alignment size for the input tensor. If not provided, the alignment size will
-              be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
+    num_gemms : int
+              number of GEMMs to be performed simultaneously.
+    align_size : int, optional
+              the alignment size for the input tensor. If not provided, the alignment size will
+              be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
+              forward pass.
     """

     def __init__(
...
@@ -88,9 +90,6 @@ class Fp8Padding(torch.nn.Module):
         super().__init__()

         self.num_gemms = num_gemms
-        if align_size is None:
-            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
-        else:
-            self.align_size = align_size
+        self.align_size = align_size

     @no_torch_dynamo()
...
@@ -111,6 +110,8 @@ class Fp8Padding(torch.nn.Module):
         """
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

+        if self.align_size is None:
+            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
+
         # FP8 padding calculate
         padded_m_splits = [
...
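With the change above, align_size is resolved from the active FP8 recipe on the first forward pass rather than at construction time. The padding itself simply rounds each per-GEMM row count up to the next multiple of that alignment; a small stand-alone sketch of the arithmetic behind the "FP8 padding calculate" step:

def pad_splits(m_splits, align_size):
    # Round each row count up to a multiple of align_size.
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]

print(pad_splits([5, 17, 32], 16))  # [16, 32, 32]
print(pad_splits([5, 17, 32], 32))  # [32, 32, 32]  (MXFP8-style alignment)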
transformer_engine/pytorch/module/fp8_unpadding.py
View file @ 44740c6c
...
@@ -29,10 +29,13 @@ class _Fp8Unpadding(torch.autograd.Function):
         is_grad_enabled: bool,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
-        inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits)
-        out_ret = torch.cat(
-            [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0
-        )
+        in_features = inp.shape[-1]
+
+        # Allocate cast and transpose output tensor
+        total_row = sum(m_splits)
+        out_ret = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device)
+        tex.fused_multi_row_unpadding(inp.view(-1, in_features), out_ret, padded_m_splits, m_splits)

         if is_grad_enabled:
             ctx.m_splits = m_splits
...
@@ -69,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module):
     Parameters
     ----------
-    num_gemms: int
-              number of GEMMs to be performed simutaneously.
-    align_size: int, optional
-              the alignment size for the input tensor. If not provided, the alignment size will
-              be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
+    num_gemms : int
+              number of GEMMs to be performed simultaneously.
+    align_size : int, optional
+              the alignment size for the input tensor. If not provided, the alignment size will
+              be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
+              forward pass.
     """

     def __init__(
...
@@ -84,9 +88,6 @@ class Fp8Unpadding(torch.nn.Module):
         super().__init__()

         self.num_gemms = num_gemms
-        if align_size is None:
-            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
-        else:
-            self.align_size = align_size
+        self.align_size = align_size

     @no_torch_dynamo()
...
@@ -107,6 +108,8 @@ class Fp8Unpadding(torch.nn.Module):
         """
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

+        if self.align_size is None:
+            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
+
         # FP8 padding calculate
         padded_m_splits = [
...
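Together, Fp8Padding and Fp8Unpadding pad a stacked per-expert input up to FP8-friendly row counts and strip the padding again afterwards. A hedged usage sketch (module names and forward signatures as documented in these two files; align_size is passed explicitly here so the example does not depend on an active FP8 recipe):

import torch
import transformer_engine.pytorch as te

num_gemms = 3
m_splits = [5, 17, 32]                                # rows routed to each GEMM/expert
inp = torch.randn(sum(m_splits), 256, device="cuda")

pad = te.Fp8Padding(num_gemms, align_size=16)
unpad = te.Fp8Unpadding(num_gemms, align_size=16)

padded, padded_m_splits = pad(inp, m_splits)          # each split rounded up to a multiple of 16
# ... run `padded` through FP8 grouped GEMMs using `padded_m_splits` ...
restored = unpad(padded, m_splits)                    # back to sum(m_splits) rows
assert restored.shape == inp.shape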
transformer_engine/pytorch/module/grouped_linear.py
View file @ 44740c6c
...
@@ -25,7 +25,6 @@ from ..fp8 import FP8GlobalStateManager
 from ..utils import (
     divide,
     cast_if_needed,
-    assert_dim_for_fp8_exec,
     clear_tensor_data,
     init_method_constant,
     requires_grad,
...
@@ -39,11 +38,12 @@ from ..distributed import (
 from ..cpp_extensions import (
     general_grouped_gemm,
 )
-from ..constants import GemmParallelModes, dist_group_type, TE_DType
+from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..cpu_offload import is_cpu_offload_enabled
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.quantized_tensor import (
     QuantizedTensorBase,
     Quantizer,
...
@@ -80,6 +80,7 @@ class _GroupedLinear(torch.autograd.Function):
         is_grad_enabled: bool,
         module,
         skip_fp8_weight_update,
+        save_original_input,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
...
@@ -88,25 +89,18 @@ class _GroupedLinear(torch.autograd.Function):
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
         device = inp.device

-        # Make sure input dimensions are compatible
-        in_features = weights[0].shape[-1]
-        assert inp.shape[-1] == in_features, "GEMM not possible"
-        inputmats = torch.split(inp.view(-1, in_features), m_splits)
-        if fp8:
-            assert_dim_for_fp8_exec(*inputmats, *weights)
-
-        # Cast input to expected dtype
-        inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
-        inputmats = []
-
         weight_requires_grad = weights[0].requires_grad

+        # Configure quantizers
+        if save_original_input and isinstance(input_quantizers[0], Float8Quantizer):
+            raise ValueError("DelayedScaling recipe is not supported with save_original_input")
         if input_quantizers[0] is not None:
             for input_quantizer in input_quantizers:
                 input_quantizer.set_usage(
                     rowwise=True,
-                    columnwise=(is_grad_enabled and weight_requires_grad),
+                    columnwise=(is_grad_enabled and weight_requires_grad and not save_original_input),
                 )
             columnwise_usage = is_grad_enabled and inp.requires_grad
             if not columnwise_usage:
...
@@ -121,17 +115,25 @@ class _GroupedLinear(torch.autograd.Function):
             for output_quantizer in output_quantizers:
                 output_quantizer.set_usage(rowwise=True, columnwise=False)

-        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
-        if fp8:
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-            if hasattr(recipe, "fp8_gemm_fprop"):
-                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
-
-            inputmats = tex.fused_multi_quantize(
-                inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
-            )
-            weights_fp8 = []
-            bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
+        # Initialize input tensors
+        in_features = weights[0].size(-1)
+        if inp.size(-1) != in_features:
+            raise ValueError(
+                f"Input tensor (shape={tuple(inp.size())}) is not compatible with "
+                f"weight tensor (shape={tuple(weights[0].size())})"
+            )
+        inp_view = inp.reshape(-1, in_features)
+        inputmats: list
+        if fp8:
+            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+        else:
+            inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

+        # Initialize weights
+        weights_fp8: list
+        if fp8:
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
             for i in range(num_gemms):
                 weight_fp8 = module.get_weight_workspace(
...
@@ -144,18 +146,29 @@ class _GroupedLinear(torch.autograd.Function):
                 weights_fp8.append(weight_fp8)

         else:
-            inputmats = inputmats_no_fp8
-            bias_dtype = activation_dtype
             weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]

+        # Initialize biases
+        bias_dtype = activation_dtype
+        if fp8 and activation_dtype == torch.float32:
+            bias_dtype = torch.bfloat16  # FP8 GEMM only supports BF16/FP16 bias
         biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases

+        # Initialize output tensor
         out = torch.empty(
             [sum(m_splits), weights_fp8[0].size(0)],
             dtype=activation_dtype,
             device=device,
         )

+        # Choose whether to use split accumulator
+        use_split_accumulator = _2X_ACC_FPROP
+        if fp8:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if hasattr(recipe, "fp8_gemm_fprop"):
+                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+
+        # Perform GEMM
         _ = general_grouped_gemm(
             weights_fp8,
             inputmats,
...
@@ -166,7 +179,7 @@ class _GroupedLinear(torch.autograd.Function):
             m_splits=m_splits,
             bias=biases,
             use_bias=use_bias,
-            use_split_accumulator=fprop_gemm_use_split_accumulator,
+            use_split_accumulator=use_split_accumulator,
         )

         if fp8_calibration:
...
@@ -183,9 +196,15 @@ class _GroupedLinear(torch.autograd.Function):
             # TODO: update after #1638 is merged. # pylint: disable=fixme
             if weight_requires_grad:
-                for inputmat in inputmats:
-                    if isinstance(inputmat, QuantizedTensorBase):
-                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
+                if save_original_input:
+                    inputmats = [None] * num_gemms
+                    inputmats[0] = inp
+                else:
+                    for inputmat in inputmats:
+                        if isinstance(inputmat, QuantizedTensorBase):
+                            inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
             else:
                 inputmats = [None] * num_gemms
             if inp.requires_grad:
                 for weight in weights_fp8:
                     if isinstance(weight, QuantizedTensorBase):
...
@@ -202,9 +221,18 @@ class _GroupedLinear(torch.autograd.Function):
         ctx.weights_requires_grad = weights[0].requires_grad
         if fuse_wgrad_accumulation and ctx.weights_requires_grad:
-            ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(weights[0], "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)]
+            else:
+                ctx.main_grad_funcs = [
+                    lambda j=i: weights[j].main_grad for i in range(num_gemms)
+                ]
         else:
-            ctx.main_grads = [None] * num_gemms
+            ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
         ctx.device = device
         ctx.grad_output_quantizers = grad_output_quantizers
         ctx.m_splits = m_splits
...
@@ -226,6 +254,8 @@ class _GroupedLinear(torch.autograd.Function):
             or FP8GlobalStateManager.is_first_fp8_module()
         )
         ctx.wgrad_store = wgrad_store
+        ctx.save_original_input = save_original_input
+        ctx.input_quantizers = input_quantizers

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp.shape[1:-1], out.shape[-1])
...
@@ -240,7 +270,7 @@ class _GroupedLinear(torch.autograd.Function):
             weights = saved_tensors[N : 2 * N]
             origin_weights = saved_tensors[2 * N : 3 * N]
             biases = saved_tensors[3 * N : 4 * N]
-            main_grads = ctx.main_grads
+            main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                 for i in range(ctx.num_gemms):
...
@@ -248,36 +278,44 @@ class _GroupedLinear(torch.autograd.Function):
                         w.main_grad = main_grads[i]
                         weights[i] = w

-            # preprocess grad_output
-            grad_output = grad_output.contiguous()
-            grad_output_mats = torch.split(
-                grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
-            )
+            # Preprocess grad output
+            grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
             grad_output = [None] * ctx.num_gemms
             grad_biases = [None] * ctx.num_gemms
             if ctx.fp8:
                 if ctx.use_bias:
-                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready
-                    # for Float8BlockQuantizer.
-                    if ctx.fp8_recipe.float8_block_scaling():
-                        for i in range(ctx.num_gemms):
-                            grad_biases[i] = grad_output_mats[i].sum(dim=0)
-                            grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
-                    else:
+                    grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
+                    recipe = ctx.fp8_recipe
+                    if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8():
+                        # Fused bias grad + quantize kernel
                         for i in range(ctx.num_gemms):
                             grad_biases[i], grad_output[i] = tex.bgrad_quantize(
-                                grad_output_mats[i], ctx.grad_output_quantizers[i]
+                                grad_output_mats[i],
+                                ctx.grad_output_quantizers[i],
                             )
+                    else:
+                        # Unfused bias grad and multi-tensor quantize
+                        for i in range(ctx.num_gemms):
+                            grad_biases[i] = grad_output_mats[i].sum(dim=0)
+                        grad_output = tex.split_quantize(
+                            grad_output_view,
+                            ctx.m_splits,
+                            ctx.grad_output_quantizers,
+                        )
                 else:
-                    grad_output = tex.fused_multi_quantize(
-                        grad_output_mats,
-                        None,
+                    # Multi-tensor quantize
+                    grad_output = tex.split_quantize(
+                        grad_output_view,
+                        ctx.m_splits,
                         ctx.grad_output_quantizers,
-                        TE_DType[ctx.activation_dtype],
                     )
             else:
-                grad_output = grad_output_mats
+                # Only split grad output. Grad bias is fused with
+                # wgrad GEMM.
+                grad_output = torch.split(
+                    cast_if_needed(grad_output_view, ctx.activation_dtype),
+                    ctx.m_splits,
+                )

             if ctx.is_first_microbatch is not None:
                 accumulate_wgrad_into_param_main_grad = (
...
@@ -334,6 +372,27 @@ class _GroupedLinear(torch.autograd.Function):
                         torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
                         for w in weights
                     ]

+                if ctx.save_original_input:
+                    inp = inputmats[0]
+                    in_features = inp.shape[-1]
+                    inp_view = inp.reshape(-1, in_features)
+                    if ctx.input_quantizers[0] is not None:
+                        for input_quantizer in ctx.input_quantizers:
+                            if isinstance(
+                                input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                            ):
+                                input_quantizer.set_usage(rowwise=True, columnwise=True)
+                            else:
+                                input_quantizer.set_usage(rowwise=False, columnwise=True)
+                    inputmats: list
+                    if ctx.fp8:
+                        inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
+                    else:
+                        inputmats = torch.split(
+                            cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
+                        )

                 grouped_gemm_wgrad = functools.partial(
                     general_grouped_gemm,
                     out_dtype=ctx.activation_dtype,
...
@@ -429,6 +488,7 @@ class _GroupedLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
             *wgrad_list,
             *grad_biases,
         )
...
@@ -479,6 +539,11 @@ class GroupedLinear(TransformerEngineBaseModule):
                       would not fit in GPU memory.
     delay_wgrad_compute : bool, default = `False`
                          Whether to delay weight gradient computation
+    save_original_input : bool, default = `False`
+                         If set to `True`, always saves the original input tensor rather than the
+                         cast tensor. In some scenarios, the input tensor is used by multiple modules,
+                         and saving the original input tensor may reduce the memory usage.
+                         Cannot work with FP8 DelayedScaling recipe.

     Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
     `parallel_mode` are used to determine the shapes of weights and biases.
...
@@ -506,6 +571,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
+        save_original_input: bool = False,
     ) -> None:
         super().__init__()
...
@@ -520,6 +586,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         self.ub_overlap_rs = ub_overlap_rs
         self.ub_overlap_ag = ub_overlap_ag
         self.ub_name = ub_name
+        self.save_original_input = save_original_input
         assert (
             not ub_overlap_rs and not ub_overlap_ag
         ), "GroupedLinear doesn't support Userbuffer overlap."
...
@@ -735,6 +802,7 @@ class GroupedLinear(TransformerEngineBaseModule):
             torch.is_grad_enabled(),
             self,
             skip_fp8_weight_update,
+            self.save_original_input,
             *weight_tensors,
             *bias_tensors,
         )
...
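One detail worth highlighting from the @@ -202 hunk above: the per-GEMM accessors are built as lambda j=i: weights[j].main_grad. Binding the loop index through a default argument makes each lambda keep its own index; a plain closure would evaluate the final loop value for every entry. A minimal, library-free illustration:

values = [10, 20, 30]

late = [lambda: values[i] for i in range(3)]       # every closure sees the final i == 2
bound = [lambda j=i: values[j] for i in range(3)]  # default argument freezes i per lambda

print([f() for f in late])   # [30, 30, 30]
print([f() for f in bound])  # [10, 20, 30]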
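The new save_original_input option documented in the @@ -479 hunk keeps the raw input for backward and re-quantizes it there, instead of holding per-GEMM quantized copies alive. A hedged usage sketch (GroupedLinear constructor and forward signature as documented here; the MoE-style split sizes are illustrative, and the flag is rejected with the DelayedScaling recipe):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

num_gemms, in_features, out_features = 4, 1024, 4096
layer = te.GroupedLinear(
    num_gemms, in_features, out_features,
    bias=True,
    save_original_input=True,  # save the raw input; quantize it again during backward
).cuda()

m_splits = [512, 256, 768, 512]  # tokens routed to each expert (multiples of 16 for FP8)
inp = torch.randn(sum(m_splits), in_features, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    out = layer(inp, m_splits)
out.sum().backward()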
transformer_engine/pytorch/module/layernorm_linear.py
View file @ 44740c6c
...
@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..cpp_extensions import (
...
@@ -454,7 +455,14 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.requires_wgrad = weight.requires_grad
             ctx.quantized_weight = quantized_weight
             if fuse_wgrad_accumulation and weight.requires_grad:
-                ctx.main_grad = weight.main_grad
+                # This check is needed to ensure that main_grad is not created
+                # during the forward pass when using MCore FSDP as it creates
+                # the main_grad buffer lazily before backprop
+                if hasattr(weight, "__fsdp_param__"):
+                    # MCore FSDP creates main_grad lazily before backward
+                    ctx.main_grad_func = weight.get_main_grad
+                else:
+                    ctx.main_grad_func = lambda: weight.main_grad
             ctx.grad_input_quantizer = grad_input_quantizer
             ctx.grad_weight_quantizer = grad_weight_quantizer
             ctx.grad_output_quantizer = grad_output_quantizer
...
@@ -500,7 +508,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if return_layernorm_output:
             if return_layernorm_output_gathered:
-                shape = list(inp.shape)
+                shape = list(inp_shape)
                 shape[0] *= tp_size if with_input_all_gather else 1
                 return out, ln_out_return.view(shape)
             return out, ln_out_return.view(inp_shape)
...
@@ -535,7 +543,7 @@ class _LayerNormLinear(torch.autograd.Function):
             # Since main_grad can be modified inplace, it should not be a part of saved_tensors
             main_grad = (
-                ctx.main_grad
+                ctx.main_grad_func()
                 if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                 else None
             )
...
@@ -1470,6 +1478,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             first microbatch (since it is the first gradient being
             produced)
         """
+        if is_in_onnx_export_mode():
+            return self.onnx_forward(inp, fp8_output)
         debug = TEDebugState.debug_enabled
         if debug:
             self._validate_name()
...
@@ -1493,12 +1503,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ) as inp:
             # Get concatenated weight and bias tensors
-            unfused_weights = self._get_weight_tensors()
-            weight_tensor = noop_cat(unfused_weights)
-            if self.use_bias:
-                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
-            else:
-                bias_tensor = getattr(self, self.bias_names[0])  # Unused
+            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

             quantizers = (
                 self._get_quantizers(fp8_output, fp8_grad)
...
@@ -1628,6 +1633,72 @@ class LayerNormLinear(TransformerEngineBaseModule):
             for name, q in zip(names, original_quantizers)
         )

+    def _get_weight_and_bias_tensors(self):
+        # Get concatenated weight and bias tensors
+        unfused_weights = self._get_weight_tensors()
+        weight_tensor = noop_cat(unfused_weights)
+        if self.use_bias:
+            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
+        else:
+            bias_tensor = getattr(self, self.bias_names[0])  # Unused
+        return weight_tensor, bias_tensor
+
+    def onnx_forward(
+        self,
+        inp: torch.Tensor,
+        fp8_output: bool,
+    ) -> torch.Tensor:
+        """
+        ONNX-compatible version of the forward function that provides numerical equivalence
+        while only using operations that have defined ONNX symbolic translations.
+        This simplified implementation is designed specifically for inference scenarios.
+        """
+        from ..export import onnx_layernorm, onnx_gemm
+
+        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
+        assert_warmed_up(self)
+        (
+            input_quantizer,
+            weight_quantizer,
+            output_quantizer,
+            *_,
+        ) = self._get_quantizers(fp8_output, fp8_grad=False)
+        inp_dtype = inp.dtype
+
+        weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
+
+        ln_out, ln_out_return = onnx_layernorm(
+            inp,
+            self.layer_norm_weight,
+            self.layer_norm_bias,
+            self.eps,
+            self.normalization,
+            self.zero_centered_gamma,
+            inp_dtype,
+            self.return_layernorm_output,
+            input_quantizer,
+        )
+
+        if weight_quantizer is not None:
+            weight_tensor_quantized = weight_quantizer.onnx_quantize(weight_tensor)
+            weight_tensor = weight_quantizer.onnx_dequantize(weight_tensor_quantized)
+        weight_tensor = weight_tensor.to(inp_dtype)
+
+        if bias_tensor is not None:
+            bias_tensor = bias_tensor.to(inp_dtype)
+
+        output = onnx_gemm(weight_tensor, ln_out, bias_tensor if self.apply_bias else None)
+
+        if output_quantizer is not None:
+            raise NotImplementedError("ONNX export of quantized output is not supported")
+
+        if self.return_layernorm_output and self.return_bias:
+            return output, bias_tensor.to(inp_dtype), ln_out_return
+        if self.return_layernorm_output:
+            return output, ln_out_return
+        if self.return_bias:
+            return output, bias_tensor.to(inp_dtype)
+        return output
+
     def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
         """Customize quantizers based on current scaling recipe + layernorm_linear."""
         assert (
...
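onnx_forward keeps quantization visible in the exported graph as an explicit quantize-then-dequantize pair around the weight, followed by an ordinary GEMM in the input dtype. A hedged, pure-PyTorch sketch of that fake-quantization pattern (scale handling is simplified to a single scalar, torch.float8_e4m3fn requires a PyTorch build that ships the float8 dtypes, and onnx_quantize/onnx_dequantize in TE wrap the real per-recipe logic):

import torch

def fake_quant_e4m3(t: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize to float8 and immediately dequantize: the graph records the precision
    # loss while the matmul itself stays in the original dtype.
    q = (t * scale).to(torch.float8_e4m3fn)
    return q.to(t.dtype) / scale

def linear_like_export(x, weight, bias=None, w_scale=1.0):
    w = fake_quant_e4m3(weight, w_scale)
    return torch.nn.functional.linear(x, w, bias)

x = torch.randn(8, 1024)
w = torch.randn(4096, 1024)
print(linear_like_export(x, w).shape)  # torch.Size([8, 4096])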