OpenDAS / TransformerEngine / Commits / 9df0c4a3

Commit 9df0c4a3 authored Mar 03, 2026 by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d

Changes: 221 files in total; showing 20 changed files with 1700 additions and 381 deletions (+1700 −381).
transformer_engine/pytorch/csrc/extensions.h                          +16   −8
transformer_engine/pytorch/csrc/extensions/activation.cpp              +8   −0
transformer_engine/pytorch/csrc/extensions/attention.cpp              +28  −27
transformer_engine/pytorch/csrc/extensions/pybind.cpp                 +14   −4
transformer_engine/pytorch/distributed.py                              +4   −4
transformer_engine/pytorch/graph.py                                   +28  −11
transformer_engine/pytorch/jit.py                                     +26   −8
transformer_engine/pytorch/module/_common.py                          +17   −8
transformer_engine/pytorch/module/base.py                            +109  −91
transformer_engine/pytorch/module/grouped_linear.py                   +91  −16
transformer_engine/pytorch/module/layernorm_linear.py                  +8   −5
transformer_engine/pytorch/module/layernorm_mlp.py                    +23   −9
transformer_engine/pytorch/module/linear.py                            +6   −8
transformer_engine/pytorch/ops/__init__.py                             +6   −4
transformer_engine/pytorch/ops/basic/__init__.py                       +3   −2
transformer_engine/pytorch/ops/basic/activation.py                    +33  −75
transformer_engine/pytorch/ops/basic/grouped_linear.py               +702   −0
transformer_engine/pytorch/ops/basic/swiglu.py                       +498   −0
transformer_engine/pytorch/ops/fused/__init__.py                      +24  −36
transformer_engine/pytorch/ops/fused/backward_activation_bias.py      +56  −65
transformer_engine/pytorch/csrc/extensions.h

...
@@ -81,15 +81,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);

 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale,
     float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
-    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::vector<int64_t> window_size, bool bottom_right_diagonal,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
+    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded,
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
...
@@ -99,10 +100,10 @@ std::vector<py::object> fused_attn_fwd(
 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
-    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
-    const at::ScalarType fake_dtype, const DType dqkv_type,
+    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
+    bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
+    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
+    const DType dqkv_type,
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...
@@ -198,6 +199,11 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
  * Activations
  **************************************************************************************************/

+/* GLU (sigmoid gate) */
+py::object glu(const at::Tensor& input, py::handle quantizer);
+py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer);
+
 /* GELU and variants*/
 py::object gelu(const at::Tensor& input, py::handle quantizer);
...
@@ -585,6 +591,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
   ~CommOverlap() {}

+  using transformer_engine::CommOverlapCore::copy_into_buffer;
   void copy_into_buffer(const at::Tensor& input, bool local_chunk = false);

   at::Tensor get_buffer(bool local_chunk = false,
...
@@ -606,6 +613,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
   ~CommOverlapP2P() {}

+  using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
   void copy_into_buffer(const at::Tensor& input, bool local_chunk = false);

   at::Tensor get_buffer(bool local_chunk = false,
...
transformer_engine/pytorch/csrc/extensions/activation.cpp

...
@@ -246,6 +246,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
   return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
 }

+py::object glu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_glu, nullptr>(input, quantizer, 2);
+}
+
+py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dglu, nullptr>(grad, input, quantizer);
+}
+
 py::object geglu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
 }
...
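The new glu/dglu pair follows the same pattern as the other gated activations: the helper dispatches to the nvte_glu/nvte_dglu kernels with a gate factor of 2, since the input's last dimension holds both halves. For reference, a minimal PyTorch sketch of the intended forward semantics (an illustration of the math, not the CUDA path), assuming an even last dimension and Transformer Engine's convention of gating the first half:

    import torch

    # Reference semantics for the new glu kernel: split the input along the
    # last dimension, apply sigmoid to the first half (the gate), and
    # multiply with the second half.
    def glu_reference(x: torch.Tensor) -> torch.Tensor:
        gate, value = x.chunk(2, dim=-1)
        return torch.sigmoid(gate) * value

    x = torch.randn(4, 8, requires_grad=True)
    y = glu_reference(x)   # shape (4, 4)
    y.sum().backward()     # dglu corresponds to this backward pass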
transformer_engine/pytorch/csrc/extensions/attention.cpp

...
@@ -45,7 +45,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
 #ifdef __HIP_PLATFORM_AMD__
   return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
...
@@ -53,7 +53,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
       bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
       max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, deterministic);
   return fused_attention_backend;
 #endif
 }
...
@@ -104,9 +104,10 @@ std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale,
     float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
-    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::vector<int64_t> window_size, bool bottom_right_diagonal,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
+    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded,
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
...
@@ -242,7 +243,7 @@ std::vector<py::object> fused_attn_fwd(
                          te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                          te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(),
                          max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph,
                          attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                         softmax_type, window_size[0], window_size[1], workspace.data(),
+                         softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
                          at::cuda::getCurrentCUDAStream());
    });
...
@@ -302,7 +303,7 @@ std::vector<py::object> fused_attn_fwd(
                          te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                          te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(),
                          max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph,
                          attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                         softmax_type, window_size[0], window_size[1], workspace.data(),
+                         softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
                          at::cuda::getCurrentCUDAStream());
    });
...
@@ -318,10 +319,10 @@ std::vector<py::object> fused_attn_fwd(
 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
-    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
-    const at::ScalarType fake_dtype, const DType dqkv_type,
+    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
+    bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
+    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
+    const DType dqkv_type,
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...
@@ -543,14 +544,14 @@ std::vector<py::object> fused_attn_bwd(
       // populate tensors with appropriate shapes and dtypes
       NVTE_SCOPED_GIL_RELEASE({
         nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                             te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                             te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
                             te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                             te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                             max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
-                            bias_type, attn_mask_type, softmax_type, window_size[0],
-                            window_size[1], deterministic, cuda_graph, workspace.data(),
-                            at::cuda::getCurrentCUDAStream());
+                            bias_type, attn_mask_type, softmax_type, window_size[0],
+                            window_size[1], bottom_right_diagonal, deterministic, cuda_graph,
+                            workspace.data(), at::cuda::getCurrentCUDAStream());
       });

       // allocate memory for workspace
...
@@ -560,14 +561,14 @@ std::vector<py::object> fused_attn_bwd(
       // execute kernel
       NVTE_SCOPED_GIL_RELEASE({
         nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                             te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                             te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
                             te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                             te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                             max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
-                            bias_type, attn_mask_type, softmax_type, window_size[0],
-                            window_size[1], deterministic, cuda_graph, workspace.data(),
-                            at::cuda::getCurrentCUDAStream());
+                            bias_type, attn_mask_type, softmax_type, window_size[0],
+                            window_size[1], bottom_right_diagonal, deterministic, cuda_graph,
+                            workspace.data(), at::cuda::getCurrentCUDAStream());
       });

       // destroy tensor wrappers
...
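Both attention entry points now thread bottom_right_diagonal down to the fused kernels. The flag controls how the causal diagonal is aligned when the query and key sequence lengths differ; a small Python illustration of the two alignments (conceptual only, not the kernel code):

    import torch

    # With a top-left diagonal, query i may attend to keys j <= i; with a
    # bottom-right diagonal, the mask is shifted so the last query attends
    # to the last key: j <= i + (seqlen_kv - seqlen_q).
    def causal_mask(seqlen_q, seqlen_kv, bottom_right_diagonal):
        i = torch.arange(seqlen_q).unsqueeze(1)
        j = torch.arange(seqlen_kv).unsqueeze(0)
        offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
        return j <= i + offset  # True where attention is allowed

    print(causal_mask(2, 4, False).int())  # diagonal anchored at the top-left
    print(causal_mask(2, 4, True).int())   # diagonal anchored at the bottom-right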
transformer_engine/pytorch/csrc/extensions/pybind.cpp

...
@@ -132,6 +132,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
         py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);

+  /* GLU (sigmoid gate) */
+  m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"),
+        py::arg("quantizer"));
   /* GELU and variants*/
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
...
@@ -158,6 +161,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
         "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
         py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);

+  /* Backward of GLU */
+  m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
...
@@ -515,8 +521,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0,
           py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16,
           py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false,
           py::arg("rs_overlap_first_gemm") = false)
-      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlap::*)(const at::Tensor&, bool)>(
+               &CommOverlap::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
           py::arg("shape") = std::nullopt)
      .def("get_communication_stream", &CommOverlap::get_communication_stream);
...
@@ -533,8 +541,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
           py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false,
           py::arg("atomic_gemm") = false, py::arg("use_ce") = true,
           py::arg("aggregate") = false)
-      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlapP2P::*)(const at::Tensor&, bool)>(
+               &CommOverlapP2P::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
           py::arg("shape") = std::nullopt)
      .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
...
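With the bindings in place, the new kernels are callable from Python through the extension module. A hedged usage sketch, assuming a CUDA build and that passing quantizer=None selects the unquantized output path, as with the other activation bindings:

    import torch
    import transformer_engine_torch as tex

    x = torch.randn(4, 8, device="cuda")
    y = tex.glu(x, None)        # forward: sigmoid(first half) * second half
    g = torch.ones_like(y)
    dx = tex.dglu(g, x, None)   # gradient with respect to the forward input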
transformer_engine/pytorch/distributed.py

...
@@ -729,8 +729,8 @@ def checkpoint(
     if isinstance(function, TransformerEngineBaseModule):
         # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
         # to scatter/gather activations that we will recompute anyway.
-        setattr(function, "fsdp_wrapped", False)
-        setattr(function, "fsdp_group", None)
+        function.fast_setattr("fsdp_wrapped", False)
+        function.fast_setattr("fsdp_group", None)

     # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
     # and execute TE's own checkpointing
...
@@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     )
     root_state = _get_module_fsdp_state(fsdp_root)
     assert root_state is not None, "Root module does not have a valid _FSDPState."
-    setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
+    fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)

     # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
     fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
...
@@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
                 "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                 "Please initialize your model without the te.quantized_model_init(...) context."
             )
-        setattr(fsdp_module.module, "fsdp_group", state.process_group)
+        fsdp_module.module.fast_setattr("fsdp_group", state.process_group)

 class FullyShardedDataParallel(FSDP):
...
transformer_engine/pytorch/graph.py

...
@@ -451,11 +451,12 @@ def _make_graphed_callables(
                 if is_training:
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
                     with _none_grad_context_wrapper(inputs):
+                        outputs_requiring_grad = tuple(
+                            o for o in outputs if o is not None and o.requires_grad
+                        )
                         torch.autograd.backward(
-                            tuple(o for o in outputs if o.requires_grad),
-                            grad_tensors=tuple(
-                                torch.empty_like(o) for o in outputs if o.requires_grad
-                            ),
+                            outputs_requiring_grad,
+                            grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
                         )
                         grad_inputs = tuple(input.grad for input in inputs)
...
@@ -616,19 +617,22 @@ def _make_graphed_callables(
             # Note for _reuse_graph_input_output_buffers: grad output is only used
             # within backward, so we can reuse the same static buffers every time.
             static_grad_outputs_keys = tuple(
-                (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                (o.shape, o.dtype, o.layout)
+                for o in static_outputs
+                if o is not None and o.requires_grad
             )
             if static_grad_outputs_keys in static_grad_outputs_dict:
                 static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
             else:
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
                     for o in static_outputs
                 )
                 static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
         else:
             static_grad_outputs = tuple(
-                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                torch.empty_like(o) if o is not None and o.requires_grad else None
+                for o in static_outputs
             )
         if is_training:
             inputs = tuple(i for i in static_input_surface if i.requires_grad)
...
@@ -636,7 +640,9 @@ def _make_graphed_callables(
             bwd_graph, pool=mempool
         ):
             torch.autograd.backward(
-                tuple(o for o in static_outputs if o.requires_grad),
+                tuple(o for o in static_outputs if o is not None and o.requires_grad),
                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                 retain_graph=retain_graph_in_backward,
             )
...
@@ -719,7 +725,8 @@ def _make_graphed_callables(
         ):
             # For now, assumes all static_outputs require grad
             static_grad_outputs = tuple(
-                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                torch.empty_like(o) if o is not None and o.requires_grad else None
+                for o in static_outputs
             )
         if is_training:
             inputs = tuple(i for i in static_input_surface if i.requires_grad)
...
@@ -727,7 +734,7 @@ def _make_graphed_callables(
             bwd_graph, pool=mempool
         ):
             torch.autograd.backward(
-                tuple(o for o in static_outputs if o.requires_grad),
+                tuple(o for o in static_outputs if o is not None and o.requires_grad),
                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                 retain_graph=retain_graph_in_backward,
             )
...
@@ -794,7 +801,7 @@ def _make_graphed_callables(
                 # Replay forward graph
                 fwd_graph.replay()
                 assert isinstance(static_outputs, tuple)
-                return tuple(o.detach() for o in static_outputs)
+                return tuple(o.detach() if o is not None else o for o in static_outputs)

             @staticmethod
             @torch.autograd.function.once_differentiable
...
@@ -853,12 +860,22 @@ def _make_graphed_callables(
         return functionalized

     def make_graphed_attribute_functions(graph_idx):
+        # Get te modules for current graph
+        te_modules = visited_te_modules.get(graph_idx, set())
+
         # Attach backward_dw as an attribute to the graphed callable.
         def backward_dw():
             if need_bwd_dw_graph.get(graph_idx, False):
                 bwd_dw_graphs[graph_idx].replay()
+            # Trigger the grad accumulation hook for wgrad graphs.
+            for module in te_modules:
+                if (
+                    isinstance(module, TransformerEngineBaseModule)
+                    and module.need_backward_dw()
+                ):
+                    module._trigger_wgrad_accumulation_and_reduce_hooks()

         # Attach reset as an attribute to the graphed callable.
         def reset():
             fwd_graphs[graph_idx].reset()
...
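The recurring change here is tolerance for None entries in the static output surface: a graphed callable may now return optional outputs, so the output tuple and the matching grad buffers must be filtered consistently before torch.autograd.backward is called. A standalone sketch of the pattern, with hypothetical outputs:

    import torch

    x = torch.randn(4, requires_grad=True)
    outputs = (x * 2, None, x.sum())  # hypothetical outputs, one of them None

    # Filter outputs and grad tensors together so they stay aligned.
    outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
    grad_tensors = tuple(torch.ones_like(o) for o in outputs_requiring_grad)
    torch.autograd.backward(outputs_requiring_grad, grad_tensors=grad_tensors)
    print(x.grad)  # gradients accumulate only from the non-None outputs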
transformer_engine/pytorch/jit.py

...
@@ -47,17 +47,35 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"

 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda func: func
 if torch.__version__ >= "2":
     import torch._dynamo

-    if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: (
-            f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
-        )
-    else:
-        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
-        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
+    def no_torch_dynamo(recursive=True):
+        """Decorator to disable Torch Dynamo, except during ONNX export."""
+
+        def decorator(f):
+            # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
+            disabled_f = (
+                torch._dynamo.disable(f, recursive=recursive)
+                if torch.__version__ >= "2.1"
+                else torch._dynamo.disable(f)
+            )
+
+            @wraps(f)
+            def wrapper(*args, **kwargs):
+                if is_in_onnx_export_mode():
+                    return f(*args, **kwargs)
+                return disabled_f(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+else:
+    # Fallback for PyTorch < 2.0: no-op decorator
+    def no_torch_dynamo(recursive=True):
+        # pylint: disable=unused-argument
+        """No-op decorator for PyTorch < 2.0."""
+        return lambda func: func

 def set_jit_fusion_options() -> None:
...
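The behavioral change is that the ONNX-export check now happens per call rather than once at decoration time, so a function decorated at import time still runs its original body when an ONNX export starts later. A usage sketch, assuming Transformer Engine is importable:

    from transformer_engine.pytorch.jit import no_torch_dynamo

    @no_torch_dynamo()
    def fused_step(x):
        return x * 2

    # Outside ONNX export this runs the torch._dynamo.disable()-wrapped
    # version; inside an export (is_in_onnx_export_mode() is True) the
    # undecorated function runs instead.
    print(fused_step(21))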
transformer_engine/pytorch/module/_common.py

...
@@ -90,6 +90,8 @@ class _NoopCatFunc(torch.autograd.Function):
         # Check first tensor
         if not tensors:
             raise ValueError("Attempted to concatenate 0 tensors")

         # Check concat dim
         num_dims = tensors[0].dim()
         if not -num_dims <= dim < num_dims:
             raise ValueError(
...
@@ -122,11 +124,24 @@ class _NoopCatFunc(torch.autograd.Function):
         ctx.dim = dim
         ctx.split_ranges = split_ranges

-        # Out-of-place concatenation if needed
+        # Tensor properties from first tensor
         dtype = tensors[0].dtype
         device = tensors[0].device
         strides = tensors[0].stride()
         data_ptr_stride = strides[dim] * tensors[0].element_size()
+
+        # Out-of-place concatenation when view tensors have different storage
+        # Note: This works around an edge case with the split_quantize
+        # function, which might allocate a buffer and construct
+        # subviews. However, in order to reduce CPU overheads, these
+        # views are configured manually outside of PyTorch. PyTorch
+        # doesn't know these views share the same memory, and it
+        # blocks us from reconstructing the full tensor because it
+        # thinks we are accessing out-of-bounds memory.
+        if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride:
+            return torch.cat(tensors, dim=dim)
+
+        # Out-of-place concatenation if tensor properties do not match
         data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride
         for tensor in tensors[1:]:
             if (
...
@@ -139,13 +154,7 @@ class _NoopCatFunc(torch.autograd.Function):
             data_ptr += tensor.size(dim) * data_ptr_stride

         # No-op concatenation
-        out = tensors[0].new()
-        out.set_(
-            tensors[0].untyped_storage(),
-            tensors[0].storage_offset(),
-            out_shape,
-            strides,
-        )
+        out = tensors[0].as_strided(out_shape, strides)
         out.requires_grad = any(tensor.requires_grad for tensor in tensors)
         return out
...
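The no-op concatenation relies on the inputs being adjacent views into one storage, in which case a strided view over the first tensor exposes the concatenated result without copying; the new storage-size check falls back to torch.cat when that assumption cannot hold. A standalone sketch of the zero-copy idea:

    import torch

    # If the inputs are adjacent views into one storage, a strided view over
    # the first tensor can expose the concatenated result without moving data.
    buf = torch.arange(12.0)
    a, b = buf[:6].view(2, 3), buf[6:].view(2, 3)  # adjacent views, dim=0 concat

    out = a.as_strided((4, 3), a.stride())  # reinterpret as the full (4, 3) tensor
    assert torch.equal(out, torch.cat([a, b], dim=0))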
transformer_engine/pytorch/module/base.py

...
@@ -10,9 +10,8 @@ import pickle
 import warnings
 from enum import Enum
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 from contextlib import contextmanager
 import logging
-from types import MethodType

 import torch
...
@@ -50,6 +49,8 @@ from ..utils import (
     is_non_tn_fp8_gemm_supported,
     torch_get_autocast_gpu_dtype,
     get_nvtx_range_context,
+    nvtx_range_push,
+    nvtx_range_pop,
 )
 from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ...common.recipe import DelayedScaling, Recipe
...
@@ -644,10 +645,10 @@ def fill_userbuffers_buffer_for_all_gather(
 class TransformerEngineBaseModule(torch.nn.Module, ABC):
     """Base TE module."""

-    def __init__(self) -> None:
+    def __init__(self, name: Optional[str] = None) -> None:
         super().__init__()
         assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
-        self.name = None
+        self.name = name
         self.next_iter_when_debug_should_be_run = 0
         self.fp8_initialized = False
         self.fp8 = False
...
@@ -672,26 +673,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             if not TEDebugState.debug_enabled:
                 TEDebugState.initialize()
+            self._validate_name()

-    # Names of attributes that can be set quickly (see __setattr__
-    # method)
-    _fast_setattr_names: Set[str] = {
-        "activation_dtype",
-        "fp8",
-        "fp8_initialized",
-        "fp8_calibration",
-        "fp8_parameters",
-    }
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name in TransformerEngineBaseModule._fast_setattr_names:
-            # torch.nn.Module has a custom __setattr__ that handles
-            # modules, parameters, and buffers. This is unnecessary
-            # overhead when setting plain attrs.
-            self.__dict__[name] = value
-        else:
-            # Default case
-            super().__setattr__(name, value)
+    def fast_setattr(self, name: str, value: Any) -> None:
+        """
+        Fast version of the Module's set attribute function.
+        Should be used for regular attributes, but not properties nor parameters/buffers.
+        """
+        self.__dict__[name] = value
+
+    def module_setattr(self, name: str, value: Any) -> None:
+        """
+        Regular version of the Module's set attribute function.
+        Should be used only when the fast version cannot be used - for the properties,
+        parameters and buffers.
+        """
+        super().__setattr__(name, value)

     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
...
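This commit replaces the allow-list __setattr__ override with explicit fast_setattr/module_setattr methods, so plain attribute writes skip torch.nn.Module.__setattr__ entirely (which scans the parameter, buffer, and submodule dicts on every assignment). A standalone sketch of the cost difference, using nothing beyond stock PyTorch:

    import timeit
    import torch

    class M(torch.nn.Module):
        def fast_setattr(self, name, value):
            # Write straight into __dict__, bypassing torch.nn.Module.__setattr__,
            # which checks for parameters, buffers, and submodules on every call.
            self.__dict__[name] = value

    m = M()
    t_slow = timeit.timeit(lambda: setattr(m, "flag", True), number=100_000)
    t_fast = timeit.timeit(lambda: m.fast_setattr("flag", True), number=100_000)
    print(f"Module.__setattr__: {t_slow:.3f}s   fast_setattr: {t_fast:.3f}s")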
@@ -812,7 +809,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             self.set_meta_tensor(True, recipe)
             self.set_meta_tensor(False, recipe)

-        self.fp8_meta_tensors_initialized = True
+        self.fast_setattr("fp8_meta_tensors_initialized", True)

     def get_fp8_meta_tensors(self) -> None:
         """Get scales and amaxes."""
...
@@ -969,7 +966,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch_get_autocast_gpu_dtype()
+            self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
             return

         # All checks after this have already been performed once, thus skip
...
@@ -984,7 +981,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     "Data types for parameters must match when outside of autocasted region. "
                     f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                 )
-        self.activation_dtype = dtype
+        self.fast_setattr("activation_dtype", dtype)

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """
...
@@ -996,8 +993,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         tp_group : ProcessGroup, default = None
                    tensor parallel process group.
         """
-        self.tp_group = tp_group
-        self.tp_group_initialized = True
+        self.fast_setattr("tp_group", tp_group)
+        self.fast_setattr("tp_group_initialized", True)

     def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
         """returns the FP8 weights."""
...
@@ -1013,48 +1010,51 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     # assume FP8 execution.
     def init_fp8_metadata(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        _original_recipe = self.fp8_meta.get("recipe", None)
-
-        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
-        fp8_enabled = self.fp8 or self.fp8_calibration
-        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
-
-        if self.fp8_parameters or fp8_enabled:
-            if (
-                self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
-            ):
+        meta = self.fp8_meta
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+        self.fast_setattr("fp8_parameters", fp8_parameters)
+        self.fast_setattr("fp8", fp8)
+        self.fast_setattr("fp8_calibration", fp8_calibration)
+        fp8_enabled = fp8 or fp8_calibration
+        meta["fp8_checkpoint"] = fp8_enabled

+        _original_recipe = None
+        if fp8_parameters or fp8_enabled:
+            _original_recipe = meta.get("recipe", None)
+            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                 # FP8 init has already been run and recipe is the same, don't do anything.
                 return
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
         else:
             # If fp8 isn't enabled, turn off and return.
-            self.fp8_initialized = False
+            self.fast_setattr("fp8_initialized", False)
             return

-        if self.fp8_parameters and not self.fp8_initialized:
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+        if fp8_parameters and not self.fp8_initialized:
+            meta["num_gemms"] = num_gemms
+            self.init_fp8_meta_tensors(meta["recipe"])

         if fp8_enabled:
             # Set FP8 and other FP8 metadata
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
+            meta["num_gemms"] = num_gemms
+            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

             # Set FP8_MAX per tensor according to recipe
-            if hasattr(self.fp8_meta["recipe"], "fp8_format"):
-                self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-                self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+            if hasattr(meta["recipe"], "fp8_format"):
+                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
+                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd

             # Allocate scales and amaxes
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
-            self.fp8_initialized = True
+            self.init_fp8_meta_tensors(meta["recipe"])
+            self.fast_setattr("fp8_initialized", True)

-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
-        _current_recipe = self.fp8_meta["recipe"]
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+        _current_recipe = meta["recipe"]
         if _original_recipe is not None and not (
             issubclass(_current_recipe.__class__, _original_recipe.__class__)
             or issubclass(_original_recipe.__class__, _current_recipe.__class__)
...
@@ -1067,22 +1067,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Clear cached workspaces as they were created with the old recipe/quantizer type
             self._fp8_workspaces.clear()

-    @contextmanager
     def prepare_forward(
         self,
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
-    ) -> Generator[torch.Tensor, None, None]:
-        """Checks and prep for FWD.
-        The context manager is needed because there isn't a way for a module to know
-        if it's the last FP8 module in the forward autocast. It is useful
-        to setup the forward aggregated amax reduction for every module
-        just in case. The autocast exit will pick up the most recent one.
-        """
-        self.allow_different_data_and_param_types = allow_different_data_and_param_types
-        self.forwarded_at_least_once = True
+    ) -> torch.Tensor:
+        """Checks and prepares for FWD execution."""
+        self.fast_setattr(
+            "allow_different_data_and_param_types", allow_different_data_and_param_types
+        )
+        self.fast_setattr("forwarded_at_least_once", True)

         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
...
@@ -1113,13 +1109,37 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             if self.training and is_fp8_activation_recompute_enabled():
                 FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

-        with get_nvtx_range_context(self.__class__.__name__ + " forward"):
-            if not allow_non_contiguous and not inp.is_contiguous():
-                inp = inp.contiguous()
-            yield inp
+        nvtx_range_push(self.__class__.__name__ + " forward")
+        if not allow_non_contiguous and not inp.is_contiguous():
+            inp = inp.contiguous()
+        return inp
+
+    def end_forward(self):
+        """
+        Required to be called at the end of the forward function to properly handle
+        DelayedScaling metadata handling and the NVTX ranges.
+        """
+        delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
+        if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
+            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        nvtx_range_pop()
+
+    @contextmanager
+    def prepare_forward_ctx(
+        self,
+        inp: torch.Tensor,
+        num_gemms: int = 1,
+        allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        """Checks and prepares for FWD execution."""
+        inp = self.prepare_forward(
+            inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
+        )
+        try:
+            yield inp
+        finally:
+            self.end_forward()

     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled
...
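prepare_forward now returns the (possibly contiguized) input directly and must be paired with end_forward; the module forwards below switch to an explicit try/finally, while prepare_forward_ctx preserves the old context-manager interface. A schematic sketch of the new calling pattern (the prints stand in for NVTX push/pop and FP8 metadata restore, and the multiplication stands in for the real GEMM path):

    class Module:
        def prepare_forward(self, inp):
            print("nvtx push")      # stands in for nvtx_range_push(...)
            return inp

        def end_forward(self):
            print("nvtx pop")       # stands in for nvtx_range_pop() + FP8 restore

        def forward(self, inp):
            inp = self.prepare_forward(inp)
            try:
                return inp * 2      # hypothetical compute
            finally:
                self.end_forward()  # runs even if the compute raises

    print(Module().forward(21))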
@@ -1354,9 +1374,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Update the parameter based on its type
             if not is_dtensor:
-                setattr(self, name, param)
+                self.module_setattr(name, param)
             else:
-                setattr(self, name, dtensor_param)
+                self.module_setattr(name, dtensor_param)

     @abstractmethod
     def forward(self):
...
@@ -1545,8 +1565,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 bias_tensor.grad = bgrad.to(bias_tensor.dtype)
         del wgrad
         del bgrad
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()
+
+    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
+        """
+        Trigger the wgrad accumulation and reduce hooks.
+        """
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()

     def is_debug_iter(self) -> bool:
         """
...
@@ -1555,7 +1581,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         debug = TEDebugState.debug_enabled
         if not debug:
             return False
-        self._validate_name()

         # If layer is run first time in new iteration,
         # we need to check if the debug should be enabled for this layer -
...
@@ -1569,14 +1594,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 debug = False
             else:
                 debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-            self.debug_last_iteration = TEDebugState.get_iteration()
-            self.debug_enabled_in_this_iteration = debug
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
+            self.fast_setattr("debug_enabled_in_this_iteration", debug)
         else:
             # If this is the same iteration as previous invocation of the module,
             # we use the debug value from the first invocation in the iteration.
             debug = self.debug_enabled_in_this_iteration
-            self.debug_last_iteration = TEDebugState.get_iteration()
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

         if self.wgrad_store is not None:
             if debug and self.wgrad_store.delay_wgrad_compute():
...
@@ -1592,7 +1617,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Sometimes features inform that they will not be enabled for particular layer
             # for multiple next iterations.
-            self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
+            self.fast_setattr(
+                "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers)
+            )

         if not run_current:
             return True
...
@@ -1604,22 +1631,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def _validate_name(self):
         """
         Validate name passed to the module.
         This is invoked in the forward() method as module names are assigned
         after Model is initialized in Megatron-LM.
-        If no name is assigned, it creates a default name with layer count as the variable.
+        It creates a default name with layer count as the variable
+        which may be changed by the user of the module.
         """
+        if self.name is not None:
+            return
+        assert TEDebugState.debug_enabled
         import nvdlfw_inspect.api as debug_api

-        if self.name is None:
-            debug_api.log_message(
-                "Names are not provided to debug modules. ",
-                "Creating and using generic names. Pass names to debug modules for better"
-                " insight. ",
-                level=logging.WARNING,
-            )
-            self.name = f"Layer_{TEDebugState.get_layer_count()}"
+        debug_api.log_message(
+            "Names are not provided to debug modules. ",
+            "Creating and using generic names. Pass names to debug modules for better"
+            " insight. ",
+            level=logging.WARNING,
+        )
+        self.name = f"Layer_{TEDebugState.get_layer_count()}"

     def _check_weight_tensor_recipe_correspondence(self) -> None:
         """
...
transformer_engine/pytorch/module/grouped_linear.py

...
@@ -15,6 +15,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
+from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
 from .base import (
     get_dummy_wgrad,
     TransformerEngineBaseModule,
...
@@ -149,7 +150,10 @@ class _GroupedLinear(torch.autograd.Function):
                 # tensors (like scales), but bulk allocation shares storage across all tensors,
                 # so if scales can't be offloaded, nothing in the group can be offloaded.
                 inputmats = tex.split_quantize(
-                    inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+                    inp_view,
+                    m_splits,
+                    input_quantizers,
+                    disable_bulk_allocation=cpu_offloading,
                 )
             elif debug:
                 inputmats = DebugQuantizer.multi_tensor_quantize(
...
@@ -367,7 +371,10 @@ class _GroupedLinear(torch.autograd.Function):
                     for i in range(ctx.num_gemms):
                         grad_biases[i] = grad_output_mats[i].sum(dim=0)
                 grad_output = DebugQuantizer.multi_tensor_quantize(
-                    grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
+                    grad_output_view,
+                    ctx.grad_output_quantizers,
+                    ctx.m_splits,
+                    ctx.activation_dtype,
                 )
             else:
                 # Only split grad output. Grad bias is fused with
...
@@ -438,7 +445,8 @@ class _GroupedLinear(torch.autograd.Function):
                 if ctx.input_quantizers[0] is not None:
                     for input_quantizer in ctx.input_quantizers:
                         if isinstance(
-                            input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                            input_quantizer,
+                            (Float8Quantizer, Float8CurrentScalingQuantizer),
                         ):
                             input_quantizer.set_usage(rowwise=True, columnwise=True)
                         else:
...
@@ -448,7 +456,10 @@ class _GroupedLinear(torch.autograd.Function):
                     inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
                 elif ctx.debug:
                     inputmats = DebugQuantizer.multi_tensor_quantize(
-                        inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
+                        inp_view,
+                        ctx.input_quantizers,
+                        ctx.m_splits,
+                        ctx.activation_dtype,
                     )
                 else:
                     inputmats = torch.split(
...
@@ -623,9 +634,9 @@ class GroupedLinear(TransformerEngineBaseModule):
         save_original_input: bool = False,
         name: Optional[str] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
-        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
+        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_gemms = num_gemms
         self.in_features = in_features
         self.out_features = out_features
...
@@ -640,13 +651,19 @@ class GroupedLinear(TransformerEngineBaseModule):
         assert (
             not ub_overlap_rs and not ub_overlap_ag
         ), "GroupedLinear doesn't support Userbuffer overlap."
+        self.init_method = init_method
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
-        self.name = name
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
-        self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
+        self._offsets = {
+            "input": 0,
+            "weight": 1,
+            "output": 2,
+            "grad_output": 0,
+            "grad_input": 1,
+        }
         self._num_fp8_tensors_per_gemm = {
             "fwd": 3,
             "bwd": 2,
...
@@ -688,7 +705,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                         self.out_features,
                         self.in_features,
                         device=device,
-                        dtype=params_dtype,
+                        dtype=self.params_dtype,
                     ),
                 ),
                 init_fn=init_method,
...
@@ -704,13 +721,13 @@ class GroupedLinear(TransformerEngineBaseModule):
                     torch.empty(
                         self.out_features,
                         device=device,
-                        dtype=params_dtype,
+                        dtype=self.params_dtype,
                     ),
                 ),
                 init_fn=init_method_constant(0.0),
             )
         else:
-            bias = torch.Tensor().to(dtype=params_dtype, device=device)
+            bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
             setattr(self, f"bias{i}", bias)

         if self.primary_weights_in_fp8:
...
@@ -734,8 +751,63 @@ class GroupedLinear(TransformerEngineBaseModule):
         if recipe.float8_current_scaling():
             self._customize_quantizers_float8_current_scaling(fwd, recipe)

+    def make_grouped_weights(self, defer_init=False) -> None:
+        """
+        Convert parameters into a GroupedTensor and re-register them as parameters.
+        """
+        if defer_init:
+            return
+
+        weight_quantizers = self._get_weight_quantizers()
+        recipe = (
+            weight_quantizers[0]._get_compatible_recipe()
+            if weight_quantizers and weight_quantizers[0] is not None
+            else None
+        )
+        if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()):
+            self.set_tensor_parallel_attributes(defer_init=defer_init)
+            return
+
+        weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+
+        # Create the weight storage.
+        grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes(
+            num_tensors=self.num_gemms,
+            shape=[(self.out_features, self.in_features)] * self.num_gemms,
+            quantizer=weight_quantizers[0],
+            dtype=self.params_dtype,
+            device=weights[0].device,
+        )
+
+        # Copy existing params into storage.
+        with torch.no_grad():
+            for i in range(self.num_gemms):
+                if self.primary_weights_in_fp8:
+                    grouped_weights.quantized_tensors[i].copy_from_storage(weights[i])
+                else:
+                    grouped_weights.quantized_tensors[i].copy_(weights[i])
+
+        # Re-register the grouped weights as parameters.
+        for i in range(self.num_gemms):
+            self.register_parameter(
+                f"weight{i}",
+                torch.nn.Parameter(grouped_weights.quantized_tensors[i]),
+                init_fn=self.init_method,
+                get_rng_state_tracker=self.get_rng_state_tracker,
+                fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
+            )
+
+        self.set_tensor_parallel_attributes(defer_init=defer_init)
+
+    def reset_parameters(self, defer_init=False):
+        super().reset_parameters(defer_init=defer_init)
+        # Grouped tensor weights is an opt-in feature.
+        if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
+            self.make_grouped_weights(defer_init=defer_init)
+
     def set_tensor_parallel_attributes(self, defer_init=False) -> None:
         """Set attributes needed for TP"""
         if not defer_init:
             # Set parallelism attributes for linear weights
...
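Contiguous grouped-weight storage is opt-in: reset_parameters checks the environment variable and only then converts the per-GEMM weights into one GroupedTensor. A hedged sketch of enabling it, with illustrative constructor arguments (the flag must be set before the module is constructed, since it is read in reset_parameters):

    import os

    os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"

    import transformer_engine.pytorch as te

    # num_gemms/in_features/out_features values are illustrative.
    layer = te.GroupedLinear(num_gemms=4, in_features=128, out_features=256)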
@@ -798,7 +870,8 @@ class GroupedLinear(TransformerEngineBaseModule):
         is_grad_enabled = torch.is_grad_enabled()

-        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
+        try:
             weight_tensors = self._get_weight_tensors()
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
...
@@ -853,6 +926,9 @@ class GroupedLinear(TransformerEngineBaseModule):
             )
             out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
+        finally:
+            self.end_forward()

         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
         return out
...
@@ -879,8 +955,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         del grad_biases_
         del wgrad_list
         del tensor_list
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()

     def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
         """Customize quantizers based on current scaling recipe + linear."""
...
@@ -932,7 +1007,7 @@ class GroupedLinear(TransformerEngineBaseModule):
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8 and not self.fp8_calibration:
+        if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8:
             return [None] * self.num_gemms
         weight_quantizers = [
             self.quantizers["scaling_fwd"][
...
@@ -941,7 +1016,7 @@ class GroupedLinear(TransformerEngineBaseModule):
             for i in range(self.num_gemms)
         ]
         for i in range(self.num_gemms):
-            weight_quantizers[i].internal = True
+            weight_quantizers[i].internal = not self.primary_weights_in_fp8
         return weight_quantizers

     def _get_quantizers(self):
...
transformer_engine/pytorch/module/layernorm_linear.py

...
@@ -1177,9 +1177,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
-        name: str = None,
+        name: Optional[str] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
...
@@ -1198,7 +1198,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.symmetric_ar_type = symmetric_ar_type
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
-        self.name = name

         if tp_group is None:
             self.tp_size = tp_size
...
@@ -1527,10 +1526,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ).is_fp8_ubuf():
             fp8_grad = True

-        with self.prepare_forward(
-            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
-        ) as inp:
+        inp = self.prepare_forward(
+            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
+        )
+        try:
             # Get concatenated weight and bias tensors
             weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
...
@@ -1609,6 +1609,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 non_tensor_args,
             )
+        finally:
+            self.end_forward()

         if self.return_layernorm_output:
             out, ln_out = out
...
transformer_engine/pytorch/module/layernorm_mlp.py

...
@@ -107,6 +107,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
+            "glu": (tex.glu, tex.dglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
             "relu": (tex.relu, tex.drelu, None),
...
@@ -123,6 +124,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         return {
             "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
             "geglu": (tex.geglu, tex.dgeglu, None),
+            "glu": (tex.glu, tex.dglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
             "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
...
@@ -145,6 +147,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
+            "glu": (tex.glu, tex.dglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
             "relu": (tex.relu, tex.drelu, None),
...
@@ -1695,7 +1698,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                     type of normalization applied.
     activation : str, default = 'gelu'
                  activation function used.
-                 Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
+                 Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
                  ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
     activation_params : dict, default = None
                         Additional parameters for the activation function.
...
@@ -1817,7 +1820,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         zero_centered_gamma: bool = False,
         device: Union[torch.device, str] = "cuda",
         ub_overlap_ag: bool = False,
-        name: str = None,
+        name: Optional[str] = None,
         ub_overlap_rs: bool = False,
         ub_overlap_rs_dgrad: bool = False,
         ub_bulk_dgrad: bool = False,
...
@@ -1826,7 +1829,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         symmetric_ar_type: Optional[str] = None,
         checkpoint: bool = False,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
...
@@ -1857,7 +1860,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 for use_fp8 in [False, True]
             )
         )
-        self.name = name
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
...
@@ -1915,7 +1917,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias = None

         # FC1 init
-        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
+        if self.activation in [
+            "geglu",
+            "glu",
+            "qgeglu",
+            "reglu",
+            "sreglu",
+            "swiglu",
+            "clamped_swiglu",
+        ]:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
...
@@ -2077,8 +2087,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
         if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
             fp8_output = True

-        with self.prepare_forward(inp, num_gemms=2) as inp:
+        inp = self.prepare_forward(inp, num_gemms=2)
+        try:
             quantizers = (
                 self._get_quantizers(fp8_output, is_grad_enabled)
                 if not debug
...
@@ -2118,7 +2129,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
             if (
                 not IS_HIP_EXTENSION
                 and self.bias_gelu_nvfusion
                 and not use_reentrant_activation_recompute()
             ):
-                self.bias_gelu_nvfusion = False
+                self.fast_setattr("bias_gelu_nvfusion", False)

             if is_grad_enabled:
                 fwd_fn = _LayerNormMLP.apply
...
@@ -2188,6 +2199,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 non_tensor_args,
             )
+        finally:
+            self.end_forward()

         if self.return_layernorm_output:
             out, ln_out = out
...
@@ -2336,6 +2350,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         activation_map = {
             "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
             "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
             "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
             "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
             * x.chunk(2, -1)[1],
...
@@ -2534,5 +2549,4 @@ class LayerNormMLP(TransformerEngineBaseModule):
         del fc2_wgrad
         del fc1_wgrad
         del fc1_bias_grad
-        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
-            wgrad_accumulation_and_reduce_hook()
+        self._trigger_wgrad_accumulation_and_reduce_hooks()
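Adding 'glu' to the gated-activation list also makes FC1 produce twice the partitioned hidden size, since gated activations consume a gate half and a value half. A small standalone sketch of the shape bookkeeping (sizes are illustrative):

    import torch

    hidden, ffn = 8, 16
    fc1 = torch.nn.Linear(hidden, 2 * ffn)  # doubled output for a gated activation
    x = torch.randn(2, hidden)
    h = fc1(x)
    gate, value = h.chunk(2, dim=-1)        # TE gates the first half
    out = torch.sigmoid(gate) * value       # 'glu': sigmoid gate
    print(out.shape)                        # torch.Size([2, 16])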
transformer_engine/pytorch/module/linear.py
View file @
9df0c4a3
...
...
@@ -429,8 +429,8 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx
.
weight_object
=
weight
if
cpu_offloading
:
mark_not_offload
(
weight
,
weightmat
,
bias
)
# TODO(ksivamani): Check memory usage
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
saved_inputmat
,
...
...
@@ -1103,7 +1103,7 @@ class Linear(TransformerEngineBaseModule):
save_original_input
:
bool
=
False
,
name
:
Optional
[
str
]
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
(
name
)
params_dtype
=
torch
.
get_default_dtype
()
if
params_dtype
is
None
else
params_dtype
self
.
in_features
=
in_features
...
...
@@ -1116,7 +1116,6 @@ class Linear(TransformerEngineBaseModule):
         self.rng_tracker_name = rng_tracker_name
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
-        self.name = name
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
...
...
@@ -1400,11 +1399,8 @@ class Linear(TransformerEngineBaseModule):
         ).is_fp8_ubuf():
             fp8_grad = True

-        with self.prepare_forward(
-            inp,
-            allow_non_contiguous=isinstance(inp, QuantizedTensor),
-        ) as inp:
+        inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
+        try:
             weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
             quantizers = (
...
...
@@ -1475,6 +1471,8 @@ class Linear(TransformerEngineBaseModule):
                 bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                 non_tensor_args,
             )
+        finally:
+            self.end_forward()

         if self.gemm_bias_unfused_add:
             out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...
...
transformer_engine/pytorch/ops/__init__.py
View file @
9df0c4a3
...
...
@@ -8,7 +8,9 @@ This operation-based API is experimental and subject to change.
 """

-from transformer_engine.pytorch.ops.basic import *
-from transformer_engine.pytorch.ops.linear import Linear
-from transformer_engine.pytorch.ops.op import FusibleOperation
-from transformer_engine.pytorch.ops.sequential import Sequential
+from .basic import *
+from .fuser import register_backward_fusion, register_forward_fusion
+from .linear import Linear
+from .op import BasicOperation, FusedOperation, FusibleOperation
+from .sequential import Sequential
+from . import fused
transformer_engine/pytorch/ops/basic/__init__.py
View file @
9df0c4a3
...
...
@@ -7,6 +7,7 @@
 from .activation import (
     GELU,
     GEGLU,
+    GLU,
     QGELU,
     QGEGLU,
     ReLU,
...
...
@@ -14,8 +15,6 @@ from .activation import (
     SReLU,
     SReGLU,
     SiLU,
-    SwiGLU,
-    ClampedSwiGLU,
 )
 from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
...
...
@@ -24,6 +23,7 @@ from .basic_linear import BasicLinear
 from .bias import Bias
 from .constant_scale import ConstantScale
 from .dropout import Dropout
+from .grouped_linear import GroupedLinear
 from .identity import Identity
 from .l2normalization import L2Normalization
 from .layer_norm import LayerNorm
...
...
@@ -32,3 +32,4 @@ from .quantize import Quantize
 from .reduce_scatter import ReduceScatter
 from .reshape import Reshape
 from .rmsnorm import RMSNorm
+from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU
transformer_engine/pytorch/ops/basic/activation.py
View file @
9df0c4a3
...
...
@@ -20,6 +20,7 @@ from .._common import maybe_dequantize
 __all__ = [
     "GELU",
     "GEGLU",
+    "GLU",
     "QGELU",
     "QGEGLU",
     "ReLU",
...
...
@@ -27,8 +28,6 @@ __all__ = [
     "SReLU",
     "SReGLU",
     "SiLU",
-    "SwiGLU",
-    "ClampedSwiGLU",
 ]
...
...
@@ -164,6 +163,38 @@ class GELU(_ActivationOperation):
         return tex.dgelu(*args, **kwargs)


+class GLU(_ActivationOperation):
+    r"""Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{GLU}(a, b) = \sigma(a) * b
+
+    where :math:`\sigma` is the sigmoid function.
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it
+       to the second half.
+
+    See `Language Modeling with Gated Convolutional Networks
+    <https://arxiv.org/abs/1612.08083>`__ and
+    `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.glu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dglu(*args, **kwargs)
+
+
 class GEGLU(_ActivationOperation):
     r"""Gaussian Error Gated Linear Unit
...
...
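The convention warning in the new `GLU` docstring is easy to verify numerically. A small sketch in plain PyTorch (no Transformer Engine kernels), showing that the TE-style gating applies the sigmoid to the first chunk, while `torch.nn.functional.glu` applies it to the second:

    import torch

    x = torch.randn(4, 8)
    a, b = x.chunk(2, dim=-1)

    te_style = torch.sigmoid(a) * b                # TE convention: gate the first half
    pt_style = torch.nn.functional.glu(x, dim=-1)  # PyTorch: a * sigmoid(b)

    assert torch.allclose(pt_style, a * torch.sigmoid(b))
    assert not torch.allclose(te_style, pt_style)  # the two conventions differ in general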
@@ -355,76 +386,3 @@ class SiLU(_ActivationOperation):
     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
         return tex.dsilu(*args, **kwargs)
-
-
-class SwiGLU(_ActivationOperation):
-    r"""Swish gated linear unit
-
-    The input tensor is split into chunks :math:`a` and :math:`b`
-    along the last dimension and the following is computed:
-
-    .. math::
-
-       \text{SwiGLU}(a, b) = \text{SiLU}(a) * b
-
-    where
-
-    .. math::
-
-       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
-
-    .. warning::
-
-       Transformer Engine's gated activations and PyTorch's GLU
-       activation follow opposite conventions for :math:`a` and
-       :math:`b`. Transformer Engine applies the gating function to
-       the first half of the input tensor, while PyTorch applies it
-       to the second half.
-
-    The Sigmoid Linear Unit (SiLU) gating function is also known as
-    the swish function. See
-    `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__
-    and `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`__.
-
-    """
-
-    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.swiglu(*args, **kwargs)
-
-    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.dswiglu(*args, **kwargs)
-
-
-class ClampedSwiGLU(_ActivationOperation):
-    r"""GPT-OSS clamped SwiGLU
-
-    Implementation based on `GPT-OSS <https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
-
-    This activation has two differences compared to the original SwiGLU:
-
-    1. Both gate and pre-activations are clipped based on the ``limit`` parameter.
-    2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish activation.
-
-    .. warning::
-
-       The input tensor is chunked along the last dimension to get the
-       gates/pre-activations, which is different from the GPT-OSS
-       implementation, where the gates/pre-activations are assumed to
-       be interleaved in the input tensor.
-
-    Parameters
-    ----------
-    limit : float
-        The clamp limit.
-    alpha : float
-        The scaling factor for the sigmoid function used in the activation.
-    cache_quantized_input : bool, default = False
-        Quantize input tensor when caching for use in the backward pass.
-
-    """
-
-    def __init__(self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False):
-        super().__init__(cache_quantized_input=cache_quantized_input)
-        self.limit = limit
-        self.alpha = alpha
-
-    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
-
-    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
transformer_engine/pytorch/ops/basic/grouped_linear.py
0 → 100644
View file @
9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for grouped linear layer."""

from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import contextlib
import math
from typing import Any, Optional

import torch

import transformer_engine_torch as tex
from ...cpp_extensions import general_grouped_gemm
from ...distributed import CudaRNGStatesTracker
from ...module.base import (
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
    get_dummy_wgrad,
)
from ...quantization import FP8GlobalStateManager, Recipe
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
from ...utils import (
    canonicalize_device,
    canonicalize_dtype,
    clear_tensor_data,
    devices_match,
    round_up_to_nearest_multiple,
)
from .._common import is_quantized_tensor, maybe_dequantize
from ..op import BasicOperation, OperationContext


class GroupedLinear(BasicOperation):
    r"""Apply multiple linear transformations: :math:`y_i = x_i W_i^T + b_i`

    This feature is experimental and subject to change.

    This is equivalent to splitting the input tensor along its first
    dimension, applying a separate ``torch.nn.Linear`` to each split,
    and concatenating along the first dimension.

    Parameters
    ----------
    num_groups : int
        Number of linear transformations.
    in_features : int
        Inner dimension of input tensor.
    out_features : int
        Inner dimension of output tensor.
    bias : bool, default = ``True``
        Apply additive bias.
    device : torch.device, default = default CUDA device
        Tensor device.
    dtype : torch.dtype, default = default dtype
        Tensor datatype.
    rng_state_tracker_function : callable
        Function that returns a ``CudaRNGStatesTracker``, which is
        used for model-parallel weight initialization.
    accumulate_into_main_grad : bool, default = ``False``
        Whether to directly accumulate weight gradients into the
        weight's ``main_grad`` attribute instead of relying on PyTorch
        autograd. The weight's ``main_grad`` must be set externally
        and there is no guarantee that ``grad`` will be set or be
        meaningful. This is primarily intended to integrate with
        Megatron-LM. If the weight tensor additionally has the
        attribute ``overwrite_main_grad`` set to ``True``,
        ``main_grad`` is overwritten instead of accumulated into.

    """

    # Operation expects input split sizes
    num_extra_inputs: int = 1
    def __init__(
        self,
        num_groups: int,
        in_features: int,
        out_features: int,
        *,
        bias: bool = True,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
        rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
        accumulate_into_main_grad: bool = False,
    ) -> None:
        super().__init__()

        # Weight tensor dimensions
        self.num_groups: int = num_groups
        self.in_features: int = in_features
        self.out_features: int = out_features
        if self.num_groups <= 0:
            raise ValueError(f"Invalid number of groups ({self.num_groups})")
        if self.in_features <= 0:
            raise ValueError(f"Invalid input size ({self.in_features})")
        if self.out_features <= 0:
            raise ValueError(f"Invalid output size ({self.out_features})")

        # Weight tensor attributes
        device = canonicalize_device(device)
        dtype = canonicalize_dtype(dtype)
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

        # Initialize recipe state if needed for natively quantized weight
        self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
        if self._with_quantized_weight:
            self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())

        # RNG state tracker
        self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]]
        self._rng_state_tracker_function = rng_state_tracker_function

        # Register weights
        self.weight0: torch.nn.Parameter
        for group_idx in range(self.num_groups):
            weight_tensor = torch.empty(
                self.out_features,
                self.in_features,
                device="meta",
                dtype=dtype,
            )
            self.register_parameter(
                f"weight{group_idx}",
                torch.nn.Parameter(weight_tensor),
            )

        # Register biases
        self.bias0: Optional[torch.nn.Parameter]
        for group_idx in range(self.num_groups):
            bias_tensor = None
            if bias:
                bias_tensor = torch.empty(
                    self.out_features,
                    device="meta",
                    dtype=dtype,
                )
                bias_tensor = torch.nn.Parameter(bias_tensor)
            self.register_parameter(f"bias{group_idx}", bias_tensor)

        # Initialize weights if needed
        if device.type != "meta":
            self.reset_parameters()

        # Whether to accumulate weight gradient into main_grad
        self._accumulate_into_main_grad: bool = accumulate_into_main_grad
    def num_quantizers(self, mode: str) -> int:
        if mode == "forward":
            return 2 * self.num_groups
        if mode == "backward":
            return self.num_groups
        return 0

    @property
    def has_bias(self) -> bool:
        """Whether an additive bias is being applied"""
        return self.bias0 is not None

    def reset_parameters(self) -> None:
        """Initialize parameter buffers and values"""

        # Parameter device
        device = self.weight0.device
        if device.type == "meta":
            device = canonicalize_device(None)

        # Initialize weight values
        # Note: Allocate a single buffer in order to support grouped
        # GEMM kernels that expect a single weight buffer.
        packed_weights = torch.empty(
            self.num_groups,
            self.out_features,
            self.in_features,
            dtype=self.weight0.dtype,
            device=device,
        )
        weights = [packed_weights[idx] for idx in range(self.num_groups)]
        for weight in weights:
            init_context = contextlib.nullcontext()
            if self._rng_state_tracker_function is not None:
                init_context = self._rng_state_tracker_function().fork()
            with init_context:
                torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

        # Quantize weights if needed
        if self._with_quantized_weight:

            # Configure quantizers
            quantizers = [
                self.get_quantizer("forward", 2 * idx + 1) for idx in range(self.num_groups)
            ]
            with_rowwise_usage = True
            with_columnwise_usage = torch.is_grad_enabled()
            for quantizer in quantizers:
                if quantizer is None:
                    raise RuntimeError(
                        "Tried to quantize weight with deferred initialization "
                        "due to meta device, but no quantizer was available. "
                        "This is most likely because the weight was initialized "
                        "within quantized_model_init, but the forward pass was not "
                        "performed within autocast."
                    )
                quantizer.set_usage(
                    rowwise=with_rowwise_usage,
                    columnwise=with_columnwise_usage,
                )
                quantizer.internal = False

            # Quantize weights
            weights = self._quantize_weights(weights, quantizers)

        # Register weights
        for group_idx, weight in enumerate(weights):
            if not isinstance(weight, torch.nn.Parameter):
                weight = torch.nn.Parameter(weight)
            setattr(self, f"weight{group_idx}", weight)

        # Initialize biases if needed
        if self.bias0 is not None:
            packed_biases = torch.zeros(
                self.num_groups,
                self.out_features,
                dtype=self.bias0.dtype,
                device=device,
            )
            for group_idx in range(self.num_groups):
                bias = torch.nn.Parameter(packed_biases[group_idx])
                setattr(self, f"bias{group_idx}", bias)
    def _quantize_weights(
        self,
        weights: Sequence[torch.Tensor],
        quantizers: Sequence[Quantizer],
    ) -> Sequence[torch.Tensor]:
        """Construct quantized weight tensors."""

        # Manually construct MXFP8 weights
        if isinstance(quantizers[0], MXFP8Quantizer):
            return self._quantize_weights_mxfp8(weights, quantizers)

        # Use quantizers to construct quantized weights
        with torch.no_grad():
            return [quantizer(weight) for quantizer, weight in zip(quantizers, weights)]

    def _quantize_weights_mxfp8(
        self,
        weights: Sequence[torch.Tensor],
        quantizers: Sequence[Quantizer],
    ) -> Sequence[MXFP8Tensor]:
        """Construct MXFP8 weight tensors.

        Instead of allocating separate buffers for each weight tensor,
        this function constructs large buffers and assigns subviews to
        each tensor. This is intended to support grouped GEMM kernels
        that expect packed buffers.

        """

        # Tensor dimensions
        num_groups = len(weights)
        out_features, in_features = weights[0].size()
        packed_shape = (num_groups, out_features, in_features)
        unpacked_shape = (out_features, in_features)

        # Tensor attributes
        device = weights[0].device
        dtype = weights[0].dtype
        requires_grad = torch.is_grad_enabled()
        with_rowwise_usage = quantizers[0].rowwise_usage
        with_columnwise_usage = quantizers[0].columnwise_usage

        # Construct packed buffers
        rowwise_data = [None] * num_groups
        rowwise_scales = [None] * num_groups
        columnwise_data = [None] * num_groups
        columnwise_scales = [None] * num_groups
        if with_rowwise_usage:
            scale_shape = (
                num_groups,
                round_up_to_nearest_multiple(out_features, 128),
                round_up_to_nearest_multiple(in_features // 32, 4),
            )
            packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device)
            packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device)
            rowwise_data = [packed_data[idx] for idx in range(num_groups)]
            rowwise_scales = [packed_scales[idx] for idx in range(num_groups)]
        if with_columnwise_usage:
            scale_shape = (
                num_groups,
                round_up_to_nearest_multiple(out_features // 32, 4),
                round_up_to_nearest_multiple(in_features, 128),
            )
            packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device)
            packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device)
            columnwise_data = [packed_data[idx] for idx in range(num_groups)]
            columnwise_scales = [packed_scales[idx] for idx in range(num_groups)]

        # Construct MXFP8 tensors and cast to MXFP8
        out = []
        with torch.no_grad():
            for group_idx in range(num_groups):
                weight = MXFP8Tensor(
                    shape=unpacked_shape,
                    dtype=dtype,
                    fp8_dtype=tex.DType.kFloat8E4M3,
                    rowwise_data=rowwise_data[group_idx],
                    rowwise_scale_inv=rowwise_scales[group_idx],
                    columnwise_data=columnwise_data[group_idx],
                    columnwise_scale_inv=columnwise_scales[group_idx],
                    quantizer=quantizers[group_idx],
                    requires_grad=requires_grad,
                    with_gemm_swizzled_scales=False,
                )
                weight.copy_(weights[group_idx])
                out.append(weight)
        return out
    def pre_first_fuser_forward(self) -> None:
        super().pre_first_fuser_forward()

        # Initialize params if needed
        if any(param.device.type == "meta" for param in self.parameters()):
            self.reset_parameters()

        # Check that weights are consistent
        dtype = self.weight0.dtype
        device = self.weight0.device
        weight_requires_grad = self.weight0.requires_grad
        weight_tensor_type = type(self.weight0.data)
        for group_idx in range(self.num_groups):
            weight = getattr(self, f"weight{group_idx}")
            if weight.dtype != dtype:
                raise RuntimeError(
                    f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})."
                )
            if not devices_match(weight.device, device):
                raise RuntimeError(
                    f"Weight {group_idx} has invalid device "
                    f"(expected {device}, got {weight.device})."
                )
            if weight.requires_grad != weight_requires_grad:
                raise RuntimeError(
                    f"Weight {group_idx} has requires_grad={weight.requires_grad}, "
                    f"but expected requires_grad={weight_requires_grad}."
                )
            if type(weight.data) != weight_tensor_type:  # pylint: disable=unidiomatic-typecheck
                raise RuntimeError(
                    f"Weight {group_idx} has invalid tensor type "
                    f"(expected {weight_tensor_type.__name__}, "
                    f"got {type(weight.data).__name__})."
                )

        # Check that biases are consistent
        for group_idx in range(self.num_groups):
            bias = getattr(self, f"bias{group_idx}")
            if self.has_bias:
                if bias is None:
                    raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized")
                if bias.dtype != dtype:
                    raise RuntimeError(
                        f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})."
                    )
                if not devices_match(bias.device, device):
                    raise RuntimeError(
                        f"Bias {group_idx} has invalid device "
                        f"(expected {device}, got {bias.device})."
                    )
                if bias.requires_grad != weight_requires_grad:
                    raise RuntimeError(
                        f"Bias {group_idx} has requires_grad={bias.requires_grad}, "
                        f"but expected requires_grad={weight_requires_grad}."
                    )
            else:
                if bias is not None:
                    raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized")

    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
        super().pre_fuser_forward(requires_grad=requires_grad)
        if FP8GlobalStateManager.is_fp8_enabled():

            # Assume weights have consistent grad requirement
            weight_requires_grad = requires_grad and self.weight0.requires_grad

            # Configure quantizer usages
            # Note: We cache the quantized input for backward pass,
            # but discard the quantized weights.
            for group_idx in range(self.num_groups):
                input_quantizer = self.get_quantizer("forward", 2 * group_idx)
                weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
                grad_output_quantizer = self.get_quantizer("backward", group_idx)
                input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
                weight_quantizer.set_usage(rowwise=True, columnwise=False)
                grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
    def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
        super().reset_recipe_state(recipe=recipe)
        for group_idx in range(self.num_groups):

            # Input/grad output quantizers use internal tensors
            input_quantizer = self.get_quantizer("forward", 2 * group_idx)
            grad_output_quantizer = self.get_quantizer("backward", group_idx)
            if input_quantizer is not None:
                input_quantizer.internal = True
            if grad_output_quantizer is not None:
                grad_output_quantizer.internal = True

            # Handle weight quantizer
            # Note: This function may be called in base class constructor,
            # before any basic linear attrs have been set.
            weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
            if weight_quantizer is None:
                pass
            elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)):
                # Make sure weight param has correct quantizer
                weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                weight_quantizer.internal = False
                getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy())
            else:
                # Use internal tensors if quantized weights will not be
                # exposed externally
                weight_quantizer.internal = (
                    not FP8GlobalStateManager.with_fp8_parameters()
                    and not getattr(self, "_with_quantized_weight", False)
                )

            # Recipe-specific configuration
            # Note: This function may be called in base class constructor,
            # before any basic linear attrs have been set.
            if recipe is not None:
                if recipe.float8_current_scaling():
                    input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                    input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                    weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
                    weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
                    grad_output_quantizer.force_pow_2_scales = (
                        recipe.fp8_quant_bwd_grad.power_2_scale
                    )
                    grad_output_quantizer.amax_epsilon_scales = (
                        recipe.fp8_quant_bwd_grad.amax_epsilon
                    )

    def op_forward(self, *args, **kwargs):
        raise RuntimeError(
            f"{self.__class__.__name__} operation has "
            f"{self.num_extra_inputs} extra tensor inputs "
            f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_forward` instead of `op_forward`."
        )

    def op_backward(self, *args, **kwargs):
        raise RuntimeError(
            f"{self.__class__.__name__} operation has "
            f"{self.num_extra_inputs} extra tensor inputs "
            f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_backward` instead of `op_backward`."
        )
    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        num_groups = self.num_groups
        has_bias = self.has_bias
        device = self.weight0.device

        # Check which grads are required
        ctx = basic_op_ctxs[0]
        input_requires_grad = ctx.requires_grad
        weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad

        # Quantizers
        input_quantizers = [None] * num_groups
        weight_quantizers = [None] * num_groups
        grad_output_quantizers = [None] * num_groups
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        if with_quantized_compute:
            for group_idx in range(num_groups):
                input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx)
                weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1)
                grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx)

        # Get autocast dtype if needed
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        else:
            dtype = self.weight0.dtype

        # Extract split sizes from extra input
        split_sizes = basic_op_extra_inputs[0][0]
        split_sizes_int = [int(s) for s in split_sizes.tolist()]
        if len(split_sizes_int) != num_groups:
            raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.")

        # Extract params
        weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)]
        bs = None
        if has_bias:
            bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)]

        # Convert weight dtype if needed
        ws = []
        for w, quantizer in zip(weights, weight_quantizers):
            if not with_quantized_compute:
                w = maybe_dequantize(w, dtype)
            elif with_quantized_compute and not is_quantized_tensor(w):
                quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
                w = quantizer(w)
            ws.append(w)

        # Split input tensor and convert dtypes if needed
        x = maybe_dequantize(input_, dtype)
        xs = None
        if with_quantized_compute:
            for quantizer in input_quantizers:
                quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
            xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
        else:
            xs = torch.split(x, split_sizes_int)

        # Allocate output tensor
        in_shape = list(input_.size())
        out_shape = in_shape[:-1] + [self.out_features]
        out = torch.empty(out_shape, dtype=dtype, device=device)

        # Perform GEMMs
        general_grouped_gemm(
            ws,
            xs,
            [out],
            [None] * num_groups,  # quantization_params
            dtype,
            m_splits=split_sizes_int,
            bias=bs,
            use_bias=has_bias,
            use_split_accumulator=_2X_ACC_FPROP,
            single_output=True,
        )

        # Prepare weight tensors for backward pass
        if not input_requires_grad:
            ws = [None] * num_groups
        elif with_quantized_compute:
            for w, weight_param in zip(ws, weights):
                if w is not weight_param:
                    w.update_usage(rowwise_usage=False, columnwise_usage=True)

        # Prepare input tensor for backward pass
        if not weight_requires_grad:
            xs = [None] * num_groups
        elif with_quantized_compute:
            for x in xs:
                x.update_usage(rowwise_usage=False, columnwise_usage=True)

        # Save state for backward pass
        if ctx.requires_grad:
            ctx.save_for_backward(split_sizes, *xs, *ws)
            ctx.with_quantized_compute = with_quantized_compute
            ctx.input_quantizers = input_quantizers
            ctx.weight_quantizers = weight_quantizers
            ctx.grad_output_quantizers = grad_output_quantizers
            ctx.grad_input_quantizers = None
            ctx.dtype = dtype
            ctx.input_requires_grad = input_requires_grad
            ctx.weight_requires_grad = weight_requires_grad

        return out, [()]
    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        num_groups = self.num_groups
        has_bias = self.has_bias
        device = self.weight0.device

        # Saved tensors from forward pass
        ctx = basic_op_ctxs[0]
        saved_tensors = ctx.saved_tensors
        split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:]
        xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:]
        ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:]

        # Split grad output tensor and convert dtypes if needed
        split_sizes_int = [int(s) for s in split_sizes.tolist()]
        dy = maybe_dequantize(grad_output, ctx.dtype)
        dys = None
        grad_biases = [None] * num_groups
        if ctx.with_quantized_compute:
            for quantizer in ctx.grad_output_quantizers:
                quantizer.set_usage(
                    rowwise=ctx.input_requires_grad,
                    columnwise=ctx.weight_requires_grad,
                )
            dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers)
            if has_bias:
                grad_biases = [
                    dy.reshape(-1, dy.size(-1)).sum(dim=0)
                    for dy in torch.split(grad_output, split_sizes_int)
                ]
        else:
            dys = torch.split(dy, split_sizes_int)
            if has_bias:
                grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys]

        # Initialize grad weight buffers
        accumulate_into_main_grad = self._accumulate_into_main_grad
        grad_weights = [None] * num_groups
        if ctx.weight_requires_grad:
            if accumulate_into_main_grad:
                # Megatron-LM wgrad fusion
                # Note: Get grad tensors from params so we can
                # accumulate directly into it.
                for group_idx in range(num_groups):
                    weight_param = getattr(self, f"weight{group_idx}")
                    if hasattr(weight_param, "__fsdp_param__"):
                        weight_param.main_grad = weight_param.get_main_grad()
                    grad_weights[group_idx] = weight_param.main_grad
                accumulate_into_main_grad = not getattr(
                    self.weight0, "overwrite_main_grad", False
                )
            else:
                weight_shape = ws[0].size()
                for group_idx in range(num_groups):
                    grad_weights[group_idx] = torch.empty(
                        weight_shape,
                        dtype=ctx.dtype,
                        device=device,
                    )
        else:
            accumulate_into_main_grad = False

        # Perform dgrad GEMMs
        grad_input = None
        if ctx.input_requires_grad:
            out_shape = list(grad_output.size())
            in_shape = out_shape[:-1] + [self.in_features]
            grad_input = torch.empty(in_shape, dtype=ctx.dtype, device=device)
            general_grouped_gemm(
                ws,
                dys,
                [grad_input],
                [None] * num_groups,  # quantization_params
                ctx.dtype,
                layout="NN",
                m_splits=split_sizes_int,
                use_split_accumulator=_2X_ACC_DGRAD,
                single_output=True,
            )

        # Perform wgrad GEMMs
        if ctx.weight_requires_grad:
            general_grouped_gemm(
                xs,
                dys,
                grad_weights,
                [None] * num_groups,  # quantization_params
                ctx.dtype,
                layout="NT",
                m_splits=split_sizes_int,
                use_split_accumulator=_2X_ACC_WGRAD,
                accumulate=accumulate_into_main_grad,
            )

        # Clear input tensors if possible
        clear_tensor_data(*xs)

        # Megatron-LM wgrad fusion
        # Note: Return dummy tensor for grad weight if needed.
        if accumulate_into_main_grad:
            grad_weights = [None] * num_groups
            for group_idx in range(num_groups):
                weight_param = getattr(self, f"weight{group_idx}")
                if hasattr(weight_param, "grad_added_to_main_grad"):
                    weight_param.grad_added_to_main_grad = True
                grad_weights[group_idx] = get_dummy_wgrad(
                    list(weight_param.size()),
                    weight_param.dtype,
                    zero=getattr(weight_param, "zero_out_wgrad", False),
                )

        grad_params = grad_weights + grad_biases if has_bias else grad_weights
        return grad_input, [grad_params], [(None,)]
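A hedged usage sketch for the new op: the per-group split sizes are passed as the operation's single extra tensor input (per ``num_extra_inputs = 1`` above), and they must sum to the input's first dimension. The exact call pattern through the ops API is assumed here:

    import torch
    import transformer_engine.pytorch as te

    num_groups, in_features, out_features = 4, 256, 512
    op = te.ops.GroupedLinear(num_groups, in_features, out_features, bias=True)

    # Rows 0-9 go through weight0, rows 10-29 through weight1, etc.
    split_sizes = torch.tensor([10, 20, 30, 40])
    x = torch.randn(100, in_features, device="cuda")
    y = op(x, split_sizes)  # assumed call signature for an op with one extra input
    assert y.shape == (100, out_features)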
transformer_engine/pytorch/ops/basic/swiglu.py
0 → 100644
View file @
9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for SwiGLU and variants."""

from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional

import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize

__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"]


class SwiGLU(BasicOperation):
    r"""Swish gated linear unit

    The input tensor is split into chunks :math:`a` and :math:`b`
    along the last dimension and the following is computed:

    .. math::

       \text{SwiGLU}(a, b) = \text{SiLU}(a) * b

    where

    .. math::

       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}

    .. warning::

       Transformer Engine's gated activations and PyTorch's GLU
       activation follow opposite conventions for :math:`a` and
       :math:`b`. Transformer Engine applies the gating function to
       the first half of the input tensor, while PyTorch applies it
       to the second half.

    The Sigmoid Linear Unit (SiLU) gating function is also known as
    the swish function. See
    `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.

    Parameters
    ----------
    cache_quantized_input : bool, default = ``False``
        Quantize input tensor when caching for use in the backward
        pass. This will typically reduce memory usage but require
        extra compute and increase numerical error. This feature is
        highly experimental.
    glu_interleave_size : int, optional
        When set, the GLU activations will use a block interleaved
        format. Instead of interpreting the input tensor as a
        concatenation of gates and linear units (e.g.
        :math:`[a_1, a_2, a_3, a_4, b_1, b_2, b_3, b_4]` in the above
        notation), it will be interpreted as alternating blocks of
        gates and linear units (e.g.
        :math:`[a_1, a_2, b_1, b_2, a_3, a_4, b_3, b_4]` when the
        interleave size is 2). This data format is highly experimental
        and is primarily intended to support some advanced fused
        kernels.

    """

    def __init__(
        self,
        *,
        cache_quantized_input: bool = False,
        glu_interleave_size: Optional[int] = None,
    ):
        super().__init__()
        self.cache_quantized_input: bool = cache_quantized_input
        self.glu_interleave_size: Optional[int] = glu_interleave_size

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:

        # Compute dtype
        dtype: torch.dtype
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        else:
            dtype = input_.dtype
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise RuntimeError(f"Unsupported dtype ({dtype})")

        # Check input tensor
        input_ = maybe_dequantize(input_.contiguous(), dtype)

        # Remove interleaving if needed
        swiglu_in = input_
        if self.glu_interleave_size is not None:
            shape = swiglu_in.size()
            swiglu_in = swiglu_in.reshape(
                -1,
                shape[-1] // (2 * self.glu_interleave_size),
                2,
                self.glu_interleave_size,
            )
            swiglu_in = swiglu_in.transpose(1, 2).contiguous()
            swiglu_in = swiglu_in.view(shape)

        # Launch kernel
        out = tex.swiglu(swiglu_in, next_op_input_quantizer)

        # Quantize input to FP8 before caching if needed
        if self.cache_quantized_input:
            input_quantizer = Float8CurrentScalingQuantizer(
                tex.DType.kFloat8E4M3,
                input_.device,
            )
            input_quantizer.set_usage(rowwise=True, columnwise=False)
            input_ = input_quantizer(input_)

        # Save state for backward pass
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(input_)
            ctx.save_for_backward(input_)
            ctx.dtype = dtype
            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        (input_,) = ctx.saved_tensors

        # Make sure tensors have correct dtypes
        x = maybe_dequantize(input_.contiguous(), ctx.dtype)
        dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)

        # Remove interleaving if needed
        swiglu_in = x
        if self.glu_interleave_size is not None:
            shape = swiglu_in.size()
            swiglu_in = swiglu_in.reshape(
                -1,
                shape[-1] // (2 * self.glu_interleave_size),
                2,
                self.glu_interleave_size,
            )
            swiglu_in = swiglu_in.transpose(1, 2).contiguous()
            swiglu_in = swiglu_in.view(shape)

        # Quantizer for grad input
        quantizer = ctx.prev_op_grad_output_quantizer
        if self.glu_interleave_size is not None:
            quantizer = None

        # Launch kernel
        grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer)

        # Apply interleaving if needed
        dx = grad_swiglu_in
        if self.glu_interleave_size is not None:
            shape = dx.size()
            dx = dx.reshape(
                -1,
                2,
                shape[-1] // (2 * self.glu_interleave_size),
                self.glu_interleave_size,
            )
            dx = dx.transpose(1, 2).contiguous()
            dx = dx.view(shape)

        # Clear input tensor if possible
        clear_tensor_data(input_)

        return dx, ()
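The `glu_interleave_size` handling above is just a blockwise de-interleave along the last dimension. A small reference sketch of what the reshape/transpose pair computes for an interleave size of 2:

    import torch

    # Interleaved layout: [a1, a2, b1, b2, a3, a4, b3, b4] with block size 2
    x = torch.tensor([[1.0, 2.0, 10.0, 20.0, 3.0, 4.0, 30.0, 40.0]])
    size = 2
    shape = x.size()

    # Same ops as op_forward: group into (gate, linear) block pairs, swap the
    # block axis, and flatten back to the concatenated layout
    y = x.reshape(-1, shape[-1] // (2 * size), 2, size).transpose(1, 2).contiguous().view(shape)
    print(y)  # tensor([[ 1.,  2.,  3.,  4., 10., 20., 30., 40.]])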
class ClampedSwiGLU(BasicOperation):
    r"""GPT-OSS clamped SwiGLU

    Implementation based on `GPT-OSS <https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.

    This activation has two differences compared to the original SwiGLU:

    1. Both gate and pre-activations are clipped based on the ``limit`` parameter.
    2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) used in the Swish activation.

    .. warning::

       The input tensor is chunked along the last dimension to get the
       gates/pre-activations, which is different from the GPT-OSS
       implementation, where the gates/pre-activations are assumed to
       be interleaved in the input tensor.

    Parameters
    ----------
    limit : float
        The clamp limit.
    alpha : float
        The scaling factor for the sigmoid function used in the activation.
    cache_quantized_input : bool, default = ``False``
        Quantize input tensor when caching for use in the backward pass.
    glu_interleave_size : int, optional
        When set, the GLU activations will use an experimental block
        interleaved format. See the corresponding option in the SwiGLU
        operation for more details.

    """

    def __init__(
        self,
        *,
        limit: float = 7.0,
        alpha: float = 1.702,
        cache_quantized_input: bool = False,
        glu_interleave_size: Optional[int] = None,
    ):
        super().__init__()
        self.limit: float = limit
        self.alpha: float = alpha
        self.cache_quantized_input: bool = cache_quantized_input
        self.glu_interleave_size: Optional[int] = glu_interleave_size

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:

        # Compute dtype
        dtype: torch.dtype
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        else:
            dtype = input_.dtype
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise RuntimeError(f"Unsupported dtype ({dtype})")

        # Check input tensor
        x = maybe_dequantize(input_.contiguous(), dtype)

        # Remove interleaving if needed
        swiglu_in = x
        if self.glu_interleave_size is not None:
            shape = swiglu_in.size()
            swiglu_in = swiglu_in.reshape(
                -1,
                shape[-1] // (2 * self.glu_interleave_size),
                2,
                self.glu_interleave_size,
            )
            swiglu_in = swiglu_in.transpose(1, 2).contiguous()
            swiglu_in = swiglu_in.view(shape)

        # Launch kernel
        out = tex.clamped_swiglu(
            swiglu_in,
            next_op_input_quantizer,
            limit=self.limit,
            alpha=self.alpha,
        )

        # Quantize input to FP8 before caching if needed
        if self.cache_quantized_input:
            input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
            input_quantizer.set_usage(rowwise=True, columnwise=False)
            x = input_quantizer(x)

        # Save state for backward pass
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(x)
            ctx.save_for_backward(x)
            ctx.dtype = dtype
            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        (input_,) = ctx.saved_tensors

        # Make sure tensors have correct dtypes
        x = maybe_dequantize(input_.contiguous(), ctx.dtype)
        dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)

        # Remove interleaving if needed
        swiglu_in = x
        if self.glu_interleave_size is not None:
            shape = swiglu_in.size()
            swiglu_in = swiglu_in.reshape(
                -1,
                shape[-1] // (2 * self.glu_interleave_size),
                2,
                self.glu_interleave_size,
            )
            swiglu_in = swiglu_in.transpose(1, 2).contiguous()
            swiglu_in = swiglu_in.view(shape)

        # Quantizer for grad input
        quantizer = ctx.prev_op_grad_output_quantizer
        if self.glu_interleave_size is not None:
            quantizer = None

        # Launch kernel
        grad_swiglu_in = tex.clamped_dswiglu(
            dy,
            swiglu_in,
            quantizer,
            limit=self.limit,
            alpha=self.alpha,
        )

        # Apply interleaving if needed
        dx = grad_swiglu_in
        if self.glu_interleave_size is not None:
            shape = dx.size()
            dx = dx.reshape(
                -1,
                2,
                shape[-1] // (2 * self.glu_interleave_size),
                self.glu_interleave_size,
            )
            dx = dx.transpose(1, 2).contiguous()
            dx = dx.view(shape)

        # Clear input tensor if possible
        clear_tensor_data(input_)

        return dx, ()
class ScaledSwiGLU(BasicOperation):
    r"""SwiGLU with post-scaling.

    If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is
    multiplied with an extra input tensor of shape
    ``(d_1, ..., d_{n-1})``.

    Parameters
    ----------
    glu_interleave_size : int, optional
        When set, the GLU activations will use an experimental block
        interleaved format. See the corresponding option in the SwiGLU
        operation for more details.

    """

    # Operation expects scales
    num_extra_inputs: int = 1

    def __init__(self, glu_interleave_size: Optional[int] = None):
        super().__init__()
        self.glu_interleave_size: Optional[int] = glu_interleave_size

    def op_forward(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{self.__class__.__name__} operation has "
            f"{self.num_extra_inputs} extra tensor inputs "
            f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_forward` instead of `op_forward`."
        )

    def op_backward(self, *args, **kwargs) -> None:
        raise RuntimeError(
            f"{self.__class__.__name__} operation has "
            f"{self.num_extra_inputs} extra tensor inputs "
            f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_backward` instead of `op_backward`."
        )

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        extra_input = basic_op_extra_inputs[0][0]

        # Determine compute dtype
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        elif isinstance(input_, torch.Tensor):
            dtype = input_.dtype
        else:
            dtype = extra_input.dtype

        # Make sure inputs are in correct dtype
        input_ = maybe_dequantize(input_, dtype)
        scales = maybe_dequantize(extra_input, dtype)

        # Remove gate interleaving if needed
        swiglu_in = input_
        if self.glu_interleave_size is not None:
            shape = swiglu_in.size()
            swiglu_in = swiglu_in.reshape(
                -1,
                shape[-1] // (2 * self.glu_interleave_size),
                2,
                self.glu_interleave_size,
            )
            swiglu_in = swiglu_in.transpose(1, 2).contiguous()
            swiglu_in = swiglu_in.view(shape)

        # Compute scaled SwiGLU
        swiglu_out = tex.swiglu(swiglu_in, None)
        out = swiglu_out * scales.unsqueeze(-1)

        # Save state for backward pass
        ctx = basic_op_ctxs[0]
        if ctx.requires_grad:
            if is_cpu_offload_enabled():
                mark_activation_offload(input_)
            ctx.input_requires_grad = True
            ctx.extra_input_requires_grad = extra_input.requires_grad
            ctx.dtype = dtype
            ctx.save_for_backward(
                input_,
                scales if ctx.input_requires_grad else None,
            )

        return out, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        ctx = basic_op_ctxs[0]
        input_, scales = ctx.saved_tensors
        input_ = maybe_dequantize(input_, ctx.dtype)
        if scales is not None:
            scales = maybe_dequantize(scales, ctx.dtype)
        grad_output = maybe_dequantize(grad_output, ctx.dtype)

        # Remove gate interleaving if needed
        swiglu_in = input_
        if self.glu_interleave_size is not None:
            shape = swiglu_in.size()
            swiglu_in = swiglu_in.reshape(
                -1,
                shape[-1] // (2 * self.glu_interleave_size),
                2,
                self.glu_interleave_size,
            )
            swiglu_in = swiglu_in.transpose(1, 2).contiguous()
            swiglu_in = swiglu_in.view(shape)

        # Compute input grad
        grad_input = None
        if ctx.input_requires_grad:
            grad_swiglu_out = grad_output * scales.unsqueeze(-1)
            grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None)
            grad_input = grad_swiglu_in
            if self.glu_interleave_size is not None:
                shape = grad_input.size()
                grad_input = grad_input.reshape(
                    -1,
                    2,
                    shape[-1] // (2 * self.glu_interleave_size),
                    self.glu_interleave_size,
                )
                grad_input = grad_input.transpose(1, 2).contiguous()
                grad_input = grad_input.view(shape)

        # Compute scales grad by recomputing SwiGLU
        grad_extra_input = None
        if ctx.extra_input_requires_grad:
            swiglu_out = tex.swiglu(swiglu_in, None)
            grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output)

        # Clear input tensor if possible
        clear_tensor_data(ctx.saved_tensors[0])  # input_

        return grad_input, [()], [(grad_extra_input,)]
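`ScaledSwiGLU` consumes its one extra tensor input as per-row scales, so an output of shape ``(d_1, ..., d_{n-1}, d_n)`` is multiplied by scales of shape ``(d_1, ..., d_{n-1})``. A hedged sketch of the shapes involved (the ops-API call pattern is assumed, as with `GroupedLinear` above):

    import torch
    import transformer_engine.pytorch as te

    op = te.ops.ScaledSwiGLU()
    x = torch.randn(32, 2 * 128, device="cuda")  # concatenated gates and linear units
    scales = torch.rand(32, device="cuda")       # one scale per output row

    y = op(x, scales)  # assumed call pattern; equivalent to swiglu(x) * scales.unsqueeze(-1)
    assert y.shape == (32, 128)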
transformer_engine/pytorch/ops/fused/__init__.py
View file @
9df0c4a3
...
...
@@ -4,39 +4,27 @@
 """Compound tensor operation supported by the operation fuser."""

-from .backward_activation_bias import (
-    BackwardActivationBias,
-    fuse_backward_activation_bias,
-)
-from .backward_add_rmsnorm import (
-    BackwardAddRMSNorm,
-    fuse_backward_add_rmsnorm,
-)
-from .backward_linear_add import (
-    BackwardLinearAdd,
-    fuse_backward_linear_add,
-)
-from .backward_linear_scale import (
-    BackwardLinearScale,
-    fuse_backward_linear_scale,
-)
-from .forward_linear_bias_activation import (
-    ForwardLinearBiasActivation,
-    fuse_forward_linear_bias_activation,
-)
-from .forward_linear_bias_add import (
-    ForwardLinearBiasAdd,
-    fuse_forward_linear_bias_add,
-)
-from .forward_linear_scale_add import (
-    ForwardLinearScaleAdd,
-    fuse_forward_linear_scale_add,
-)
-from .userbuffers_backward_linear import (
-    UserbuffersBackwardLinear,
-    fuse_userbuffers_backward_linear,
-)
-from .userbuffers_forward_linear import (
-    UserbuffersForwardLinear,
-    fuse_userbuffers_forward_linear,
-)
+from ..fuser import register_backward_fusion, register_forward_fusion
+from .backward_activation_bias import BackwardActivationBias
+from .backward_add_rmsnorm import BackwardAddRMSNorm
+from .backward_linear_add import BackwardLinearAdd
+from .backward_linear_scale import BackwardLinearScale
+from .forward_linear_bias_activation import ForwardLinearBiasActivation
+from .forward_linear_bias_add import ForwardLinearBiasAdd
+from .forward_linear_scale_add import ForwardLinearScaleAdd
+from .userbuffers_backward_linear import UserbuffersBackwardLinear
+from .userbuffers_forward_linear import UserbuffersForwardLinear
+
+# Register forward fusions
+register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops)
+register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops)
+register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops)
+register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops)
+
+# Register backward fusions
+register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
+register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
+register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
+register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
+register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)
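The new `register_forward_fusion`/`register_backward_fusion` entry points make the fusion tables extensible from outside this package. A hedged sketch of a third-party registration, following the `fuse_*_ops(ops, *, recipe=None, **unused)` signature used by `BackwardActivationBias` below (`fuse_my_ops` and its pattern matching are hypothetical):

    from transformer_engine.pytorch.ops import register_backward_fusion

    def fuse_my_ops(ops, *, recipe=None, **unused):
        # Hypothetical pass: scan the backward op list and splice in a custom
        # fused op wherever a known pattern of basic ops appears.
        return ops  # return the (possibly rewritten) op list

    register_backward_fusion(fuse_my_ops)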
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
View file @
9df0c4a3
...
...
@@ -53,8 +53,8 @@ class BackwardActivationBias(FusedOperation):
     ]:

         # Get basic operation contexts
-        activation_op_ctx = basic_op_ctxs[0]
-        bias_op_ctx = basic_op_ctxs[1]
+        bias_op_ctx = basic_op_ctxs[0]
+        activation_op_ctx = basic_op_ctxs[1]

         # Saved tensors from forward pass
         (act_input,) = activation_op_ctx.saved_tensors
...
...
@@ -79,68 +79,59 @@ class BackwardActivationBias(FusedOperation):
         # Clear activation input tensor
         clear_tensor_data(act_input)

-        return dx, [(), (db,)], [(), ()]
+        return dx, [(db,), ()], [(), ()]

-
-def fuse_backward_activation_bias(
-    ops: list[tuple[FusibleOperation, list[int]]],
-    recipe: Optional[Recipe],
-) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fused backward dact + dbias + quantize
-
-    Parameters
-    ----------
-    ops : list of tuples
-        Backward pass operations and the indices of the corresponding
-        basic operations.
-    recipe : Recipe, optional
-        Used quantization recipe
-
-    Returns
-    -------
-    ops : list of tuples
-        Updated backward pass operations
-
-    """
-
-    # Check if recipe supports bias activation fusion
-    if recipe is None:
-        return ops
-
-    # Scan through ops, fusing if possible
-    out = []
-    window = []
-    while len(ops) >= 3:
-
-        # Check if first op is a supported activation
-        window, ops = ops[:1], ops[1:]
-        op, _ = window[0]
-        if not isinstance(op, _fusible_activations):
-            continue
-
-        # Check if second op is bias
-        op, _ = ops[0]
-        if not isinstance(op, Bias):
-            continue
-
-        # Check if third op has a grad input quantizer
-        op, _ = ops[1]
-        if not op.num_quantizers("backward") > 0:
-            continue
-        window.extend(ops[:1])
-        ops = ops[1:]
-
-        # Replace window with fused op
-        op = BackwardActivationBias(
-            activation=window[0][0],
-            bias=window[1][0],
-        )
-        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
-        window = [(op, basic_op_idxs)]
-
-    # Return list of ops
-    out.extend(window)
-    out.extend(ops)
-    return out
+    @staticmethod
+    def fuse_backward_ops(
+        ops: list[FusibleOperation],
+        *,
+        recipe: Optional[Recipe] = None,
+        **unused,  # pylint: disable=unused-argument
+    ) -> list[FusibleOperation]:
+        """Apply operation fusion for backward pass.
+
+        Parameters
+        ----------
+        ops : list of FusibleOperation
+            Backward pass operations.
+        recipe : Recipe, optional
+            Quantization recipe.
+
+        Returns
+        -------
+        ops : list of FusibleOperation
+            Updated backward pass operations
+
+        """
+
+        # Check if recipe supports bias activation fusion
+        if recipe is None:
+            return ops
+
+        # Scan through ops, fusing if possible
+        out = []
+        window, ops = ops[:3], ops[3:]
+        while len(window) == 3:
+            if (
+                isinstance(window[2], _fusible_activations)
+                and isinstance(window[1], Bias)
+                and window[0].get_grad_output_quantizer() is not None
+            ):
+                # Construct fused op if window matches pattern
+                op = BackwardActivationBias(bias=window[1], activation=window[2])
+                window = [window[0], op]
+            else:
+                # Shift window if window doesn't match pattern
+                out.extend(window[:-2])
+                window = window[-2:]
+
+            # Adjust window to expected size
+            out.extend(window[:-3])
+            window = window[-3:]
+            while ops and len(window) < 3:
+                window.append(ops[0])
+                ops = ops[1:]
+
+        # Return list of ops
+        out.extend(window)
+        return out
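For reference, the rewritten matcher looks for a three-op window consisting of an op that exposes a grad output quantizer, followed by `Bias`, followed by a fusible activation; the latter two are replaced by a single `BackwardActivationBias`. In a model built with the ops API, that window arises from the usual linear-bias-activation ordering, e.g. (a hedged sketch; the `Bias` constructor signature is assumed):

    import transformer_engine.pytorch as te

    # Forward order: linear -> bias -> activation. When the linear has a grad
    # output quantizer (e.g. under an FP8 recipe), the backward pass fuses
    # dact + dbias + quantize into one BackwardActivationBias op.
    model = te.ops.Sequential(
        te.ops.BasicLinear(128, 128),
        te.ops.Bias(128),  # assumed constructor signature
        te.ops.GELU(),
    )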