OpenDAS / TransformerEngine

Commit 27ddce40, authored Oct 11, 2025 by wenjh
Merge branch 'nv_main'
Parents: d262ef4c, 5b3092a0
Changes: 208
Showing 20 changed files with 1228 additions and 352 deletions (+1228 -352):

  transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py  +365 -95
  transformer_engine/pytorch/attention/dot_product_attention/utils.py             +54  -16
  transformer_engine/pytorch/attention/inference.py                               +23  -5
  transformer_engine/pytorch/attention/multi_head_attention.py                    +7   -17
  transformer_engine/pytorch/attention/rope.py                                    +161 -2
  transformer_engine/pytorch/cpu_offload.py                                       +20  -10
  transformer_engine/pytorch/csrc/extensions.h                                    +52  -13
  transformer_engine/pytorch/csrc/extensions/activation.cpp                       +40  -22
  transformer_engine/pytorch/csrc/extensions/apply_rope.cpp                       +99  -1
  transformer_engine/pytorch/csrc/extensions/cast.cpp                             +4   -10
  transformer_engine/pytorch/csrc/extensions/dropout.cpp                          +89  -0
  transformer_engine/pytorch/csrc/extensions/gemm.cpp                             +47  -30
  transformer_engine/pytorch/csrc/extensions/normalization.cpp                    +54  -4
  transformer_engine/pytorch/csrc/extensions/pybind.cpp                           +38  -10
  transformer_engine/pytorch/csrc/quantizer.cpp                                   +4   -48
  transformer_engine/pytorch/graph.py                                             +45  -15
  transformer_engine/pytorch/module/__init__.py                                   +1   -1
  transformer_engine/pytorch/module/base.py                                       +90  -36
  transformer_engine/pytorch/module/grouped_linear.py                             +2   -3
  transformer_engine/pytorch/module/layernorm_linear.py                           +33  -14
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
...
@@ -4,7 +4,7 @@
 """Context Parallelism."""
 import os
-from typing import List, Union
+from typing import List, Union, Tuple
 import torch
 import transformer_engine_torch as tex
...
@@ -358,7 +358,7 @@ def get_fa_args(
                 max_seqlen_q,
                 max_seqlen_kv,
-                *[None] * 8,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
+                *[None] * 9,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale
             ]
         return [
             *[None]
...
@@ -366,7 +366,7 @@ def get_fa_args(
             max_seqlen_q,
             max_seqlen_kv,
-            *[None] * 8,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
+            *[None] * 9,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale
         ]
     if qkv_format == "thd":
         return [
...
@@ -829,6 +829,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if not enable_mla:
+                        # If MHA, then split the KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (kv_inputs[i % 2][..., 0, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0])
+                        v_part = (kv_inputs[i % 2][..., 1, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1])
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
...
@@ -838,19 +851,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=max_seqlen_q,
                         max_seqlen_kv=max_seqlen_kv,
                     )
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q_inputs[i % 2],
-                        (kv_inputs[i % 2][..., 0, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0]),
-                        (kv_inputs[i % 2][..., 1, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1]),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=True,
                         **fa_forward_kwargs,
...
@@ -985,6 +989,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if enable_mla:
+                        k_part = k_part.contiguous()
+                        v_part = v_part.contiguous()
+                    else:
+                        # If MHA, then split the KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (kv_inputs[i % 2][..., 0, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0])
+                        v_part = (kv_inputs[i % 2][..., 1, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1])
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
...
@@ -1001,19 +1021,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     elif fa_utils.v2_7_0_plus:
                         fa_forward_kwargs["window_size_left"] = -1
                         fa_forward_kwargs["window_size_right"] = -1
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q_inputs[i % 2],
-                        (kv_inputs[i % 2][..., 0, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0]),
-                        (kv_inputs[i % 2][..., 1, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1]),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=False,
                         **fa_forward_kwargs,
...
@@ -1144,6 +1155,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if not enable_mla:
+                        # If MHA, then split the KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (kv_inputs[i % 2][..., 0, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0])
+                        v_part = (kv_inputs[i % 2][..., 1, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1])
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
...
@@ -1160,19 +1184,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     elif fa_utils.v2_7_0_plus:
                         fa_forward_kwargs["window_size_left"] = -1
                         fa_forward_kwargs["window_size_right"] = -1
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q_inputs[i % 2],
-                        (kv_inputs[i % 2][..., 0, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0]),
-                        (kv_inputs[i % 2][..., 1, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1]),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=False,
                         **fa_forward_kwargs,
...
@@ -1269,6 +1284,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if not enable_mla:
+                        # If MHA, then split the KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (kv_inputs[i % 2][..., 0, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0])
+                        v_part = (kv_inputs[i % 2][..., 1, :, :]
+                                  if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1])
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
...
@@ -1278,19 +1306,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=max_seqlen_q,
                         max_seqlen_kv=max_seqlen_kv,
                     )
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q,
-                        (kv_inputs[i % 2][..., 0, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][0]),
-                        (kv_inputs[i % 2][..., 1, :, :]
-                         if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1]),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=False,
                         **fa_forward_kwargs,
...
@@ -1865,7 +1884,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         dv_ = dv_._data
                     else:
                         dq_ = torch.empty_like(q_)
-                        dkv_ = torch.empty_like(kv_)
+                        if ctx.enable_mla:
+                            dk_ = torch.empty_like(k_part)
+                            dv_ = torch.empty_like(v_part)
+                        else:
+                            k_part = (kv_[..., 0, :, :]
+                                      if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0])
+                            v_part = (kv_[..., 1, :, :]
+                                      if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1])
+                            dkv_ = torch.empty_like(kv_)
+                            dk_ = (dkv_[..., 0, :, :]
+                                   if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0])
+                            dv_ = (dkv_[..., 1, :, :]
+                                   if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1])
                     fa_backward_args_thd = get_fa_args(
                         False,
                         ctx.use_flash_attn_3,
...
@@ -1875,16 +1914,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=ctx.max_seqlen_q,
                         max_seqlen_kv=ctx.max_seqlen_kv,
                         dq=dq_,
-                        dk=(
-                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0]
-                        ),
-                        dv=(
-                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1]
-                        ),
+                        dk=dk_,
+                        dv=dv_,
                     )
                     if ctx.use_flash_attn_3 or (
                         fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
...
@@ -1895,12 +1926,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                             fa_backward_kwargs["window_size_right"] = 0
                     if not ctx.use_flash_attn_3:
                         fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-                    # Need to add MLA support once Flash Attention supports MLA
                     flash_attn_bwd(
                         dout_,
                         q_,
-                        kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
-                        kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
+                        k_part,
+                        v_part,
                         out_,
                         softmax_lse,
                         *fa_backward_args_thd,
...
@@ -2016,7 +2046,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         dv_ = dv_._data
                     else:
                         dq_ = torch.empty_like(q_)
-                        dkv_ = torch.empty_like(kv_)
+                        if ctx.enable_mla:
+                            k_part = k_part.contiguous()
+                            v_part = v_part.contiguous()
+                            dk_ = torch.empty_like(k_part)
+                            dv_ = torch.empty_like(v_part)
+                        else:
+                            k_part = (kv_[..., 0, :, :]
+                                      if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0])
+                            v_part = (kv_[..., 1, :, :]
+                                      if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1])
+                            dkv_ = torch.empty_like(kv_)
+                            dk_ = (dkv_[..., 0, :, :]
+                                   if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0])
+                            dv_ = (dkv_[..., 1, :, :]
+                                   if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1])
                     fa_backward_args_thd = get_fa_args(
                         False,
                         ctx.use_flash_attn_3,
...
@@ -2026,16 +2078,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=ctx.max_seqlen_q,
                         max_seqlen_kv=ctx.max_seqlen_kv // 2,
                         dq=dq_,
-                        dk=(
-                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0]
-                        ),
-                        dv=(
-                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1]
-                        ),
+                        dk=dk_,
+                        dv=dv_,
                     )
                     if ctx.use_flash_attn_3 or (
                         fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
...
@@ -2046,12 +2090,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                             fa_backward_kwargs["window_size_right"] = -1
                     if not ctx.use_flash_attn_3:
                         fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-                    # Need to add MLA support once Flash Attention supports MLA
                     flash_attn_bwd(
                         dout_,
                         q_,
-                        kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
-                        kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
+                        k_part,
+                        v_part,
                         out_,
                         softmax_lse,
                         *fa_backward_args_thd,
...
@@ -2160,7 +2203,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         dv_ = dv_._data
                     else:
                         dq_ = torch.empty_like(q_)
-                        dkv_ = torch.empty_like(kv_)
+                        if ctx.enable_mla:
+                            dk_ = torch.empty_like(k_part)
+                            dv_ = torch.empty_like(v_part)
+                        else:
+                            k_part = (kv_[..., 0, :, :]
+                                      if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0])
+                            v_part = (kv_[..., 1, :, :]
+                                      if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1])
+                            dkv_ = torch.empty_like(kv_)
+                            dk_ = (dkv_[..., 0, :, :]
+                                   if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0])
+                            dv_ = (dkv_[..., 1, :, :]
+                                   if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1])
                     fa_backward_args_thd = get_fa_args(
                         False,
                         ctx.use_flash_attn_3,
...
@@ -2170,16 +2233,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=ctx.max_seqlen_q // 2,
                         max_seqlen_kv=ctx.max_seqlen_kv,
                         dq=dq_,
-                        dk=(
-                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0]
-                        ),
-                        dv=(
-                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1]
-                        ),
+                        dk=dk_,
+                        dv=dv_,
                     )
                     if ctx.use_flash_attn_3 or (
                         fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
...
@@ -2190,12 +2245,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                             fa_backward_kwargs["window_size_right"] = -1
                     if not ctx.use_flash_attn_3:
                         fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-                    # Need to add MLA support once Flash Attention supports MLA
                     flash_attn_bwd(
                         dout_,
                         q_,
-                        kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
-                        kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
+                        k_part,
+                        v_part,
                         out_,
                         softmax_lse_,
                         *fa_backward_args_thd,
...
@@ -2267,7 +2321,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             else:
                 dq_ = torch.empty_like(q)
-                dkv_ = torch.empty_like(kv)
+                if ctx.enable_mla:
+                    dk_ = torch.empty_like(k_part)
+                    dv_ = torch.empty_like(v_part)
+                else:
+                    k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
+                    v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
+                    dkv_ = torch.empty_like(kv)
+                    dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0]
+                    dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1]
             fa_backward_args_thd = get_fa_args(
                 False,
                 ctx.use_flash_attn_3,
...
@@ -2277,8 +2339,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 max_seqlen_q=ctx.max_seqlen_q,
                 max_seqlen_kv=ctx.max_seqlen_kv,
                 dq=dq_,
-                dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
-                dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
+                dk=dk_,
+                dv=dv_,
             )
             if ctx.use_flash_attn_3 or (
                 fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
             ):
                 fa_backward_kwargs["window_size"] = (-1, -1)
...
@@ -2287,12 +2349,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 fa_backward_kwargs["window_size_right"] = -1
             if not ctx.use_flash_attn_3:
                 fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-            # Need to add MLA support once Flash Attention supports MLA
             flash_attn_bwd(
                 dout,
                 q,
-                kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
-                kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
+                k_part,
+                v_part,
                 out,
                 softmax_lse,
                 *fa_backward_args_thd,
...
@@ -3927,3 +3988,212 @@ def attn_forward_func_with_cp(
         raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
     return out
+
+
+def pad_thd_sequences_for_cp(
+    input_ids: torch.Tensor,
+    labels: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    divisibility_factor: int,
+    padding_token_id: int = 0,
+    padding_label_id: int = -100,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Pads sequences to be divisible by the divisibility factor.
+
+    Args:
+        input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences
+        labels: Tensor of shape (1, N) or (N,) containing labels for each token
+        cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
+        divisibility_factor: Each sequence length must be divisible by this factor
+        padding_token_id: Token ID to use for padding (default: 0)
+        padding_label_id: Label ID to use for padding (default: -100)
+
+    Returns:
+        Tuple of:
+        - input_ids_padded: Padded input_ids tensor
+        - labels_padded: Padded labels tensor
+        - cu_seqlens_padded: Cumulative sequence lengths accounting for padding
+    """
+    # Flatten input_ids and labels if needed
+    if input_ids.dim() == 2:
+        input_ids = input_ids.squeeze(0)
+    if labels.dim() == 2:
+        labels = labels.squeeze(0)
+
+    # Compute the sequence lengths from cu_seqlens
+    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor)
+    padding_amounts = [
+        ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item()
+        for l in seqlens
+    ]
+
+    # Extract sequences and labels for each batch item
+    batch_sequences = [
+        input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
+    ]
+    batch_labels = [
+        labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
+    ]
+
+    # Pad sequences and labels to required length
+    input_ids_padded = torch.cat(
+        [
+            (torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) if pad > 0 else seq)
+            for seq, pad in zip(batch_sequences, padding_amounts)
+        ]
+    )
+    labels_padded = torch.cat(
+        [
+            (torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) if pad > 0 else seq)
+            for seq, pad in zip(batch_labels, padding_amounts)
+        ]
+    )
+
+    # Compute cumulative padded sequence lengths, starting from 0
+    padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
+    cu_seqlens_padded = torch.cumsum(
+        torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0
+    )
+
+    return input_ids_padded, labels_padded, cu_seqlens_padded
+
+
+def generate_positional_ids_for_cp(
+    cu_seqlens: torch.Tensor,
+    divisibility_factor: int,
+    dtype: torch.dtype = torch.long,
+) -> torch.Tensor:
+    """Generate positional IDs for sequences padded to be divisible by divisibility_factor.
+
+    Args:
+        cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
+        divisibility_factor: Each sequence length must be divisible by this factor
+        dtype: Data type for the generated positional IDs (default: torch.long)
+
+    Returns:
+        Generated positional_ids tensor where each sequence starts from 0 and continues through padding
+    """
+    # Compute the sequence lengths from cu_seqlens
+    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    # List: amount of padding needed for each sequence
+    padding_amounts = [
+        ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item()
+        for l in seqlens
+    ]
+
+    # Generate positional IDs for each padded sequence (each starts from 0)
+    padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
+    positional_ids = torch.cat([torch.arange(0, int(length), dtype=dtype) for length in padded_lengths])
+
+    return positional_ids
+
+
+def get_batch_on_this_cp_rank(
+    cu_seqlens_padded: torch.Tensor,
+    input_ids_padded: torch.Tensor,
+    labels_padded: torch.Tensor,
+    position_ids_padded: torch.Tensor,
+    cp_group: torch.distributed.ProcessGroup = None,
+    qvk_format: str = "thd",
+):
+    """Slice batch input along the sequence dimension into multiple chunks for THD format,
+    which are parallelized across GPUs in a context parallel group.
+
+    This function is intended for use in self attention. It will not work for cross attention
+    because it does not handle the case where the sequence lengths of the query and key differ.
+    This version works with variable-length sequences using cumulative sequence lengths.
+    """
+    if qvk_format not in ["thd", "bshd", "sbhd"]:
+        raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
+    if qvk_format == "thd":
+        # Get context parallel size and rank
+        cp_size = torch.distributed.get_world_size(group=cp_group)
+        if cp_size > 1:
+            cp_rank = torch.distributed.get_rank(group=cp_group)
+            # Calculate the chunk sizes for each sequence
+            total_slices_of_any_sequence = 2 * cp_size
+            slice_sizes = (
+                cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]
+            ) // total_slices_of_any_sequence
+
+            # Process each tensor directly instead of using keys_to_change loop
+            def process_tensor(val):
+                if val is None:
+                    return val
+                # Determine which dimension is the sequence dimension
+                # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
+                if isinstance(cu_seqlens_padded[-1], torch.Tensor):
+                    seq_len_val = cu_seqlens_padded[-1].item()
+                else:
+                    seq_len_val = cu_seqlens_padded[-1]
+                # Handle 1D tensors (like position_ids that don't have batch dimension)
+                if val.ndim == 1:
+                    if val.shape[0] == seq_len_val:
+                        current_seq_dim = 0
+                    else:
+                        raise ValueError(
+                            "1D tensor shape doesn't match expected sequence length. Make sure the"
+                            " inputs are in THD format and padded correctly."
+                        )
+                elif val.ndim >= 2:
+                    if val.shape[1] == seq_len_val:
+                        current_seq_dim = 1
+                    elif val.shape[0] == seq_len_val:
+                        current_seq_dim = 0
+                    else:
+                        raise ValueError("Make sure the inputs are in THD format and padded correctly.")
+                else:
+                    raise ValueError("Tensor must be at least 1D")
+                # On this particular rank, for each sequence, get two slices, one from the beginning
+                # and one from the end.
+                cp_rank_slices = []
+                for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
+                    # 1st segment
+                    cp_rank_slices.append(
+                        torch.arange(
+                            seq_start + (cp_rank * slice_size),
+                            seq_start + ((cp_rank + 1) * slice_size),
+                            device=val.device,
+                        )
+                    )
+                    # 2nd segment
+                    cp_rank_slices.append(
+                        torch.arange(
+                            seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
+                            seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+                            device=val.device,
+                        )
+                    )
+                return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
+
+            # Process each tensor directly
+            input_ids_padded = process_tensor(input_ids_padded)
+            labels_padded = process_tensor(labels_padded)
+            position_ids_padded = process_tensor(position_ids_padded)
+    else:
+        raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")
+    return input_ids_padded, labels_padded, position_ids_padded
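The three helpers added above are meant to be chained: pad each packed (THD) sequence to a multiple of 2 * cp_size, regenerate position IDs over the padded lengths, and then slice out this rank's two chunks per sequence. A minimal usage sketch follows; the tensor values and the process-group setup are illustrative assumptions, not part of this diff, and the import path simply mirrors the file shown above.

    import torch
    from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
        pad_thd_sequences_for_cp,
        generate_positional_ids_for_cp,
        get_batch_on_this_cp_rank,
    )

    # Two packed sequences of lengths 5 and 3 in THD format (toy values).
    input_ids = torch.tensor([11, 12, 13, 14, 15, 21, 22, 23])
    labels = input_ids.clone()
    cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)

    cp_size = 2
    divisibility_factor = 2 * cp_size  # every sequence is later cut into 2 * cp_size slices

    ids_p, labels_p, cu_seqlens_p = pad_thd_sequences_for_cp(
        input_ids, labels, cu_seqlens, divisibility_factor
    )
    pos_ids_p = generate_positional_ids_for_cp(cu_seqlens, divisibility_factor)
    # cu_seqlens_p is now [0, 8, 12]: lengths 5 and 3 are padded up to 8 and 4.

    # With torch.distributed initialized and `cp_group` a process group of size cp_size,
    # each rank keeps one slice from the front and one from the back of every sequence:
    # ids_rank, labels_rank, pos_rank = get_batch_on_this_cp_rank(
    #     cu_seqlens_p, ids_p, labels_p, pos_ids_p, cp_group=cp_group
    # )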
transformer_engine/pytorch/attention/dot_product_attention/utils.py
...
@@ -126,10 +126,10 @@ class FlashAttentionUtils:
     # Please follow these instructions to install FA3
     v3_installation_steps = """\
 (1) git clone https://github.com/Dao-AILab/flash-attention.git
-(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
 (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
 (4) mkdir -p $python_path/flash_attn_3
-(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
+(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
     v3_warning_printed = False

     @staticmethod
...
@@ -438,8 +438,10 @@ def get_attention_backend(
    #     | FP8            | non-paged/paged | sm90      | thd           | >= 1
    # Unfused | FP32/FP16/BF16 | non-paged/paged | all   | bshd,sbhd,thd | >= 1
    if inference_params is not None:
-        if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
-            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
+        # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version
+        # until the cuDNN bug is resolved
+        if device_compute_capability == (8, 9):
+            logger.debug("Disabling FusedAttention for KV caching for sm89")
            use_fused_attention = False
        if context_parallel:
            logger.debug("Disabling all backends for KV caching with context parallelism")
...
@@ -482,11 +484,10 @@ def get_attention_backend(
    # Filter: Head dimension
    if not IS_HIP_EXTENSION:
        if head_dim_qk != head_dim_v:
-            if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
-                use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
-            ):
-                logger.debug("Disabling FlashAttention as it does not support MLA.")
-            use_flash_attention = False
+            if use_flash_attention_2 and FlashAttentionUtils.is_installed:
+                logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
+            use_flash_attention_2 = False
            qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
            if use_fused_attention and qkv_layout_group != "hd_hd_hd":
                logger.debug(
...
@@ -518,10 +519,41 @@ def get_attention_backend(
            ".".join([str(i) for i in device_compute_capability]),
        )
        use_flash_attention_2 = False
-    if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
-        if FlashAttentionUtils.v3_is_installed:
-            logger.debug("Disabling FlashAttention 3 for head_dim > 128")
-        use_flash_attention_3 = False
+    if use_flash_attention_3:
+
+        def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
+            if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
+                return False
+            if head_dim_qk != head_dim_v:
+                cond1 = 128 < head_dim_qk <= 192
+                cond2 = 96 < head_dim_v <= 128
+                cond3 = head_dim_qk <= 64 and head_dim_v <= 512
+                if not ((cond1 and cond2) or cond3):
+                    return False
+                if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
+                    return False
+            return True
+
+        if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
+            if FlashAttentionUtils.v3_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, "
+                    "head_dim_qk, head_dim_v or qkv_dtype. "
+                    "Supported: head_dim_qk <= 256, and num_heads %% num_gqa_groups = 0, and "
+                    "if head_dim_qk is different from head_dim_v, then "
+                    "(head_dim_qk must be in (128, 192] and head_dim_v in (96, 128]) or "
+                    "(head_dim_qk <= 64 and head_dim_v <= 512), and "
+                    "if head_dim_qk is different from head_dim_v and head_dim_v > 256, then "
+                    "qkv_dtype requires fp16 and bf16 data type. "
+                    "Found: num_heads = %s, num_gqa_groups = %s, "
+                    "head_dim_qk = %s, head_dim_v = %s and qkv_dtype = %s.",
+                    num_heads,
+                    num_gqa_groups,
+                    head_dim_qk,
+                    head_dim_v,
+                    qkv_dtype,
+                )
+            use_flash_attention_3 = False

    # Filter: QKV layout
    if qkv_format == "thd":
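The head-dimension rules encoded in the new `_is_fa3_supported` check above can be worked through on a few concrete shapes. The sketch below is a standalone restatement of those conditions for illustration only, not TE's API:

    import torch

    def fa3_head_dims_ok(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
        """Standalone restatement of the FA3 head-dim conditions shown above (illustrative)."""
        if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
            return False
        if head_dim_qk != head_dim_v:
            cond1 = 128 < head_dim_qk <= 192
            cond2 = 96 < head_dim_v <= 128
            cond3 = head_dim_qk <= 64 and head_dim_v <= 512
            if not ((cond1 and cond2) or cond3):
                return False
            if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
                return False
        return True

    # MLA-style shapes: 192/128 passes, 192/64 does not, 64/512 needs fp16/bf16.
    assert fa3_head_dims_ok(32, 8, 192, 128, torch.bfloat16)
    assert not fa3_head_dims_ok(32, 8, 192, 64, torch.bfloat16)
    assert fa3_head_dims_ok(32, 8, 64, 512, torch.float16)
    assert not fa3_head_dims_ok(32, 8, 64, 512, torch.float32)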
...
@@ -838,7 +870,7 @@ def get_attention_backend(
    # flash-attn >=2.4.1         | yes
    # FusedAttention             |
    #   sub-backend 0            | yes
-   #   sub-backend 1            | workspace optimization path and sm90+: yes;
+   #   sub-backend 1            | workspace optimization path and sm90: yes;
    #                            | otherwise: no
    #   sub-backend 2            | no
    # UnfusedDotProductAttention | yes
...
@@ -854,8 +886,9 @@ def get_attention_backend(
        use_flash_attention_2 = False
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
            use_fused_attention = False
            fused_attention_backend = None
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
...
@@ -865,8 +898,13 @@ def get_attention_backend(
                or cudnn_version < (8, 9, 5)
            )
        ):
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
            use_fused_attention = False
            fused_attention_backend = None
+        if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
+            logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
+            use_fused_attention = False
+            fused_attention_backend = None
    # use_flash_attention may have been set above
    use_flash_attention_2 = use_flash_attention and use_flash_attention_2
...
transformer_engine/pytorch/attention/inference.py
...
@@ -215,6 +215,17 @@ class InferenceParams:
            device=torch.cuda.current_device(),
        )

+        # This internal buffer holds the running length of each
+        # unfinished sequence in the batch and is updated in the `pre_step()`
+        # method. One use of this buffer is applying RoPE to q and k tensors
+        # during inference by slicing RoPE embeddings according to the
+        # current sequence length window.
+        self.pre_step_seqlens = torch.zeros(
+            self.max_batch_size,
+            dtype=torch.int32,
+            device=torch.cuda.current_device(),
+        )
+
    def reset(self):
        """Reset InferenceParams state"""
        self.sequences = OrderedDict()
...
@@ -266,6 +277,15 @@ class InferenceParams:
        for k, v in self.sequences.items():
            self.sequences_pre_step[k] = v - step_dict[k]

+        pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to(
+            dtype=torch.int32, device="cpu"
+        )
+        # Copy the pre-step seqlens to the device in a CUDA-Graphs-safe manner.
+        self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_(
+            pre_step_seqlens_temp, non_blocking=False
+        )
+
        seqlens_q = list(step_dict.values())
        cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
        cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
...
@@ -280,9 +300,7 @@ class InferenceParams:
    def get_seqlens_pre_step(self):
        """Get cached sequence lengths before the stepping"""
-        return torch.Tensor(list(self.sequences_pre_step.values())).to(
-            dtype=torch.int32, device="cpu"
-        )
+        return self.pre_step_seqlens

    def convert_paged_to_nonpaged(self, layer_number: int):
        """
...
@@ -458,14 +476,14 @@ class NonPagedKVCacheManager(KVCacheManager):
        finished_seqs = self.sequences.keys() - unfinished_seqs
        unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
        finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
-        self.batch_indices.copy_(
+        self.batch_indices.data[:].copy_(
            torch.Tensor(
                (
                    unfinished_indices
                    + finished_indices
                    + list(range(prev_batch_size, self.max_batch_size))
                )
            ).to(dtype=torch.int32, device="cpu")
        )
        # Advance unfinished sequences
...
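For reference, the cumulative-length bookkeeping done in `pre_step()` above is easy to work through by hand. A small sketch with assumed per-step lengths (the numbers are not from the diff):

    # Suppose 3 active sequences contribute 3, 5 and 2 new tokens this step,
    # with max_batch_size = 5.
    seqlens_q = [3, 5, 2]
    batch_size, max_batch_size = 3, 5

    cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, batch_size + 1)]
    cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (max_batch_size - batch_size)
    assert cu_seqlens_q == [0, 3, 8, 10, 10, 10]  # padded entries repeat the final offset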
transformer_engine/pytorch/attention/multi_head_attention.py
...
@@ -889,23 +889,11 @@ class MultiheadAttention(torch.nn.Module):
            q_pos_emb, k_pos_emb = rotary_pos_emb

-            # adjust key and value for inference
-            if inference_params is not None:
-                if self.qkv_format == "sbhd":
-                    sequence_length = key_layer.size(0)
-                elif self.qkv_format == "bshd":
-                    sequence_length = key_layer.size(1)
-                else:
-                    raise ValueError(
-                        f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
-                    )
-                sequence_start = inference_params.get_seqlens_pre_step()
-                # sequence_start = inference_params.seqlens[0]
-                sequence_end = sequence_start + sequence_length
-                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
-                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
+            # Applying RoPE for inference needs start positions of sequences
+            # for each iteration.
+            sequence_start_positions = (
+                inference_params.get_seqlens_pre_step() if inference_params is not None else None
+            )

            if pad_between_seqs:
                rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
...
@@ -922,6 +910,7 @@ class MultiheadAttention(torch.nn.Module):
                cu_seqlens=rotary_pos_cu_seq_lens_q,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
+               start_positions=sequence_start_positions,
                interleaved=self.rotary_pos_interleaved,
            )
            key_layer = apply_rotary_pos_emb(
...
@@ -932,6 +921,7 @@ class MultiheadAttention(torch.nn.Module):
                cu_seqlens=rotary_pos_cu_seq_lens_kv,
                cp_size=self.cp_size,
                cp_rank=self.cp_rank,
+               start_positions=sequence_start_positions,
                interleaved=self.rotary_pos_interleaved,
            )
...
transformer_engine/pytorch/attention/rope.py
...
@@ -5,14 +5,14 @@
 """
 Rotary Position Embedding implementation of different types along with helper functions
 """
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List

 import torch

 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat

-__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
+__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]


 class RotaryPositionEmbedding(torch.nn.Module):
...
@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function):
         return grad_input, None, None, None, None, None, None, None


+class FusedQKVRoPEFunc(torch.autograd.Function):
+    """
+    Function for FusedQKVRoPE.
+
+    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format. Q and K RoPE
+    tensors are the additional required inputs; the RoPE tensors should be of shape (s, 1, 1, d).
+    It produces 3 outputs: Q and K after RoPE, and V, which is the same as the input.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        qkv: torch.Tensor,
+        q_freqs: torch.Tensor,
+        k_freqs: torch.Tensor,
+        qkv_split_arg_list: List[int],
+        start_positions: Union[torch.Tensor, None] = None,
+        tensor_format: str = "sbhd",
+        interleaved: bool = False,
+        cp_size: int = 1,
+        cp_rank: int = 0,
+    ) -> torch.Tensor:
+        """Fused RoPE forward."""
+        if q_freqs.dtype != torch.float32:
+            q_freqs = q_freqs.float()
+        if k_freqs.dtype != torch.float32:
+            k_freqs = k_freqs.float()
+        assert tensor_format in (
+            "sbhd",
+            "bshd",
+        ), f"Unsupported tensor_format: {tensor_format}."
+        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
+        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
+        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
+        output = tex.fused_qkv_rope_forward(
+            qkv,
+            q_freqs,
+            k_freqs,
+            start_positions,
+            qkv_split_arg_list,
+            QKVFormat[tensor_format],
+            interleaved,
+            cp_size,
+            cp_rank,
+        )
+        ctx.save_for_backward(q_freqs, k_freqs)
+        ctx.tensor_format = tensor_format
+        ctx.qkv_split_arg_list = qkv_split_arg_list
+        ctx.cp_size = cp_size
+        ctx.cp_rank = cp_rank
+        ctx.interleaved = interleaved
+
+        return output
+
+    @staticmethod
+    def backward(
+        ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
+    ) -> Tuple[Union[torch.Tensor, None], ...]:
+        """Fused RoPE backward."""
+        q_freqs, k_freqs = ctx.saved_tensors
+        grad_output_q = grad_output_q.contiguous()
+        grad_output_k = grad_output_k.contiguous()
+        grad_output_v = grad_output_v.contiguous()
+        grad_input = tex.fused_qkv_rope_backward(
+            grad_output_q,
+            grad_output_k,
+            grad_output_v,
+            q_freqs,
+            k_freqs,
+            ctx.qkv_split_arg_list,
+            QKVFormat[ctx.tensor_format],
+            ctx.interleaved,
+            ctx.cp_size,
+            ctx.cp_rank,
+        )
+
+        return grad_input, None, None, None, None, None, None, None, None
+
+
 def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
     """Change sign so the last dimension becomes [-odd, +even]
...
@@ -393,3 +473,82 @@ def apply_rotary_pos_emb(
         tensor_format,
         interleaved=interleaved,
     )
+
+
+def apply_fused_qkv_rotary_pos_emb(
+    qkv: torch.Tensor,
+    q_freqs: torch.Tensor,
+    k_freqs: torch.Tensor,
+    qkv_split_arg_list: List[int],
+    tensor_format: str = "sbhd",
+    start_positions: Union[torch.Tensor, None] = None,
+    interleaved: bool = False,
+    cu_seqlens: Union[torch.Tensor, None] = None,  # pylint: disable=unused-argument
+    cp_size: int = 1,
+    cp_rank: int = 0,
+) -> torch.Tensor:
+    """
+    Apply rotary positional embedding tensor to the input qkv tensor.
+
+    Support matrix:
+    Fused:
+        Training:
+            qkv_formats: "bshd", "sbhd"
+            context parallel: yes
+            start_positions: no
+            interleaving: yes
+        Inference:
+            qkv_formats: "bshd", "sbhd"
+            context parallelism: no
+            start_positions: yes
+            interleaving: yes
+
+    Parameters
+    ----------
+    qkv: torch.Tensor
+        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
+        rotary positional embedding will be applied. This tensor has q, k, v concatenated
+        along the last dimension.
+    q_freqs: torch.Tensor
+        Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
+        with `s2 >= s` and `d2 <= d`.
+    k_freqs: torch.Tensor
+        Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
+        with `s2 >= s` and `d2 <= d`.
+    qkv_split_arg_list: List[int]
+        List of integers that specify the split of the qkv tensor. The list should have 3
+        elements: the first is the number of elements in the q tensor, the second is the
+        number of elements in the k tensor, and the third is the number of elements in the
+        v tensor. The sum of the elements should equal the last dimension of the qkv tensor.
+    start_positions: torch.Tensor, default = None.
+        Tokens in a sequence `i` should be applied with position encoding offset by
+        `start_positions[i]`. If `start_positions=None`, there's no offset.
+    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
+        is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
+        of shape `[seq, bs, ...]`.
+    interleaved: bool, default = False
+        Whether to use interleaved rotary position embedding.
+    cp_size: int, default = 1.
+        Context parallel world size.
+    cp_rank: int, default = 0.
+        Context parallel rank.
+    """
+    # `start_positions` is only supported for `cp_size=1` and inference.
+    assert not (
+        cp_size > 1 and start_positions is not None
+    ), """start_positions != None with CP SIZE > 1 is not supported!"""
+
+    assert tensor_format != "thd", "'thd' tensor_format not supported currently."
+
+    return FusedQKVRoPEFunc.apply(
+        qkv,
+        q_freqs,
+        k_freqs,
+        qkv_split_arg_list,
+        start_positions,
+        tensor_format,
+        interleaved,
+        cp_size,
+        cp_rank,
+    )
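A minimal usage sketch of the new fused QKV RoPE entry point, for the simple case where Q, K and V all share the same head dimension. The shapes and the frequency construction via RotaryPositionEmbedding are illustrative assumptions consistent with the docstring above, not code from this diff:

    import torch
    from transformer_engine.pytorch.attention.rope import (
        RotaryPositionEmbedding,
        apply_fused_qkv_rotary_pos_emb,
    )

    s, b, h, d = 128, 2, 16, 64  # sequence, batch, heads, per-head dim
    qkv = torch.randn(s, b, h, 3 * d, device="cuda", dtype=torch.bfloat16)

    # RoPE frequency tensor of shape [s, 1, 1, d]; the same table is reused for Q and K here.
    freqs = RotaryPositionEmbedding(d)(max_seq_len=s).cuda()

    q, k, v = apply_fused_qkv_rotary_pos_emb(
        qkv,
        q_freqs=freqs,
        k_freqs=freqs,
        qkv_split_arg_list=[d, d, d],  # last dim of qkv split into q/k/v parts
        tensor_format="sbhd",
    )
    # q.shape == k.shape == v.shape == (s, b, h, d); V passes through unchanged.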
transformer_engine/pytorch/cpu_offload.py
...
@@ -559,17 +559,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
        buffer_idx = 0
        double_buffer_idx = group_to_reload % 2
+       main_stream = torch.cuda.current_stream()
        with torch.cuda.stream(self.h2d_stream):
            # move back tensors
            for tensor_label, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_label
                if group_id == group_to_reload:
-                   if self.double_buffering:
-                       reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
-                   else:
-                       reload_buffer = None
                    if isinstance(state, tuple):
+                       if self.double_buffering:
+                           reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
+                       else:
+                           with torch.cuda.stream(main_stream):
+                               reload_buffer = torch.empty_like(
+                                   state[1], device=torch.cuda.current_device()
+                               )
                        recovered_tensor = SynchronizedGroupOffloadHandler.reload(
                            state, True, reload_buffer
                        )
...
@@ -578,14 +584,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
                    elif isinstance(state, list):
                        tensor_list = []
                        for state_tuple in state:
-                           if self.double_buffering:
-                               reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
-                           else:
-                               reload_buffer = None
                            if isinstance(state_tuple, tuple):
+                               if self.double_buffering:
+                                   reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
+                               else:
+                                   with torch.cuda.stream(main_stream):
+                                       reload_buffer = torch.empty_like(
+                                           state_tuple[1], device=torch.cuda.current_device()
+                                       )
                                tensor_list.append(
                                    SynchronizedGroupOffloadHandler.reload(
                                        state_tuple,
...
transformer_engine/pytorch/csrc/extensions.h
...
@@ -190,38 +190,49 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
 * Activations
 **************************************************************************************************/
(This hunk regroups the activation declarations by family and adds silu, dsilu, sreglu and dsreglu; resulting declarations:)

/* GELU and variants*/
py::object gelu(const at::Tensor &input, py::handle quantizer);
py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object geglu(const at::Tensor &input, py::handle quantizer);
py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object qgelu(const at::Tensor &input, py::handle quantizer);
py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object qgeglu(const at::Tensor &input, py::handle quantizer);
py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/* ReLU and variants*/
py::object relu(const at::Tensor &input, py::handle quantizer);
py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object reglu(const at::Tensor &input, py::handle quantizer);
py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object srelu(const at::Tensor &input, py::handle quantizer);
py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object sreglu(const at::Tensor &input, py::handle quantizer);
py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/* Silu and variants*/
py::object silu(const at::Tensor &input, py::handle quantizer);
py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
...
@@ -244,6 +255,11 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                     const at::Tensor &rsigma, const at::Tensor &gamma,
                                     const int sm_margin, const bool zero_centered_gamma);

+std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
+                                        const at::Tensor &add, const at::Tensor &rsigma,
+                                        const at::Tensor &gamma, const int sm_margin,
+                                        const bool zero_centered_gamma);
+
 std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
                                     py::object ln_out, py::handle quantizer, DType otype,
                                     const int sm_margin, const bool zero_centered_gamma);
...
@@ -285,6 +301,17 @@ std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Te
 std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                      py::handle quantizer);

+/***************************************************************************************************
+ * Dropout
+ **************************************************************************************************/
+
+std::vector<py::object> dropout_fwd(const py::handle &input, const float dropout_probability,
+                                    std::optional<at::Tensor> out = std::nullopt);
+
+py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
+                       const float dropout_probability,
+                       std::optional<at::Tensor> grad_input = std::nullopt);
+
 /***************************************************************************************************
  * Softmax
  **************************************************************************************************/
...
@@ -349,6 +376,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
                                const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                                const int cp_rank);

+std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
+    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
+    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
+    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
+    const int cp_rank);
+
+at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
+                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
+                                   const at::Tensor &k_freqs,
+                                   const std::vector<int> &qkv_split_arg_list,
+                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                   const int cp_size, const int cp_rank);
+
 /***************************************************************************************************
  * Miscellaneous
  **************************************************************************************************/
...
transformer_engine/pytorch/csrc/extensions/activation.cpp
...
@@ -101,6 +101,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
   return grad_input_py;
 }

+/* GELU and variants*/
 py::object gelu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_gelu>(input, quantizer);
 }
...
@@ -109,30 +110,39 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
(This hunk and the next regroup the activation bindings by family and add SReLU-GLU and SiLU bindings; resulting code:)

py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
}

py::object geglu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_geglu>(input, quantizer, 2);
}

py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
}

py::object qgelu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_qgelu>(input, quantizer);
}

py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
}

py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
}

py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
}

/* ReLU and variants*/
py::object relu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_relu>(input, quantizer);
}

py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
}

py::object reglu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_reglu>(input, quantizer, 2);
}
...
@@ -141,28 +151,36 @@ py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu
  return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
}

py::object srelu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_srelu>(input, quantizer);
}

py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
}

py::object sreglu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_sreglu>(input, quantizer, 2);
}

py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
}

/* Silu and variants*/
py::object silu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_silu>(input, quantizer);
}

py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
}

py::object swiglu(const at::Tensor& input, py::handle quantizer) {
  return activation_helper<nvte_swiglu>(input, quantizer, 2);
}

py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
}

}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp
...
...
@@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto
freqs_cu
=
makeTransformerEngineTensor
(
freqs
);
auto
output_cu
=
makeTransformerEngineTensor
(
output
);
auto
start_positions_cu
=
TensorWrapper
();
// empty
cu_seqle
ns tensor
auto
start_positions_cu
=
TensorWrapper
();
// empty
start_positio
ns tensor
if
(
start_positions
)
{
start_positions_cu
=
makeTransformerEngineTensor
(
start_positions
.
value
());
TORCH_CHECK
(
start_positions_cu
.
ndim
()
==
1
,
"expected 1D tensor"
);
}
if
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_THD
)
{
...
...
@@ -102,6 +103,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
return
output
;
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
fused_qkv_rope_forward
(
const
at
::
Tensor
&
qkv_input
,
const
at
::
Tensor
&
q_freqs
,
const
at
::
Tensor
&
k_freqs
,
const
std
::
optional
<
at
::
Tensor
>
start_positions
,
const
std
::
vector
<
int
>
&
qkv_split_arg_list
,
const
NVTE_QKV_Format
qkv_format
,
const
bool
interleaved
,
const
int
cp_size
,
const
int
cp_rank
)
{
TORCH_CHECK
(
q_freqs
.
dim
()
==
4
,
"expected 4D tensor"
);
TORCH_CHECK
(
q_freqs
.
size
(
1
)
==
1
&&
q_freqs
.
size
(
2
)
==
1
,
"expected the second and third dims of the freqs tensor equal 1"
);
TORCH_CHECK
(
q_freqs
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"Dtype of the freqs tensor must be float"
);
TORCH_CHECK
(
k_freqs
.
dim
()
==
4
,
"expected 4D tensor"
);
TORCH_CHECK
(
k_freqs
.
size
(
1
)
==
1
&&
k_freqs
.
size
(
2
)
==
1
,
"expected the second and third dims of the freqs tensor equal 1"
);
TORCH_CHECK
(
k_freqs
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"Dtype of the freqs tensor must be float"
);
// output
auto
act_options
=
at
::
TensorOptions
().
dtype
(
qkv_input
.
scalar_type
()).
device
(
qkv_input
.
device
());
auto
q_out_size
=
qkv_input
.
sizes
().
vec
();
q_out_size
[
2
]
=
q_out_size
[
2
]
*
qkv_split_arg_list
[
0
]
/
qkv_split_arg_list
[
1
];
q_out_size
[
3
]
=
qkv_split_arg_list
[
1
];
auto
q_out
=
at
::
empty
(
q_out_size
,
act_options
);
auto
k_out_size
=
qkv_input
.
sizes
().
vec
();
k_out_size
[
3
]
=
qkv_split_arg_list
[
1
];
auto
k_out
=
at
::
empty
(
k_out_size
,
act_options
);
auto
v_out_size
=
qkv_input
.
sizes
().
vec
();
v_out_size
[
3
]
=
qkv_split_arg_list
[
2
];
auto
v_out
=
at
::
empty
(
v_out_size
,
act_options
);
auto
qkv_cu
=
makeTransformerEngineTensor
(
qkv_input
);
auto
q_freqs_cu
=
makeTransformerEngineTensor
(
q_freqs
);
auto
k_freqs_cu
=
makeTransformerEngineTensor
(
k_freqs
);
auto
q_out_cu
=
makeTransformerEngineTensor
(
q_out
);
auto
k_out_cu
=
makeTransformerEngineTensor
(
k_out
);
auto
v_out_cu
=
makeTransformerEngineTensor
(
v_out
);
auto
start_positions_cu
=
TensorWrapper
();
// empty cu_seqlens tensor
if
(
start_positions
)
{
start_positions_cu
=
makeTransformerEngineTensor
(
start_positions
.
value
());
}
TORCH_CHECK
(
qkv_input
.
dim
()
==
4
,
"expected 4D input tensor"
);
TORCH_CHECK
(
qkv_input
.
is_contiguous
(),
"input tensor must be contiguous"
);
const
bool
is_sbhd
=
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
;
const
int
s
=
is_sbhd
?
qkv_input
.
size
(
0
)
:
qkv_input
.
size
(
1
);
const
int
b
=
is_sbhd
?
qkv_input
.
size
(
1
)
:
qkv_input
.
size
(
0
);
const
int
h
=
qkv_input
.
size
(
2
);
const
int
d
=
qkv_split_arg_list
[
2
];
const
int
d2
=
q_freqs
.
size
(
3
);
nvte_fused_qkv_rope_forward
(
qkv_cu
.
data
(),
q_freqs_cu
.
data
(),
k_freqs_cu
.
data
(),
start_positions_cu
.
data
(),
q_out_cu
.
data
(),
k_out_cu
.
data
(),
v_out_cu
.
data
(),
qkv_format
,
interleaved
,
cp_size
,
cp_rank
,
s
,
b
,
h
,
d
,
d2
,
qkv_split_arg_list
[
0
],
qkv_split_arg_list
[
1
],
qkv_split_arg_list
[
2
],
at
::
cuda
::
getCurrentCUDAStream
());
return
std
::
make_tuple
(
q_out
,
k_out
,
v_out
);
}
at
::
Tensor
fused_rope_backward
(
const
at
::
Tensor
&
output_grads
,
const
at
::
Tensor
&
freqs
,
const
NVTE_QKV_Format
qkv_format
,
const
bool
interleaved
,
const
std
::
optional
<
at
::
Tensor
>
cu_seqlens
,
const
int
cp_size
,
...
...
@@ -193,4 +253,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
return
input_grads
;
}
at
::
Tensor
fused_qkv_rope_backward
(
const
at
::
Tensor
&
q_grad_out
,
const
at
::
Tensor
&
k_grad_out
,
const
at
::
Tensor
&
v_grad_out
,
const
at
::
Tensor
&
q_freqs
,
const
at
::
Tensor
&
k_freqs
,
const
std
::
vector
<
int
>
&
qkv_split_arg_list
,
const
NVTE_QKV_Format
qkv_format
,
const
bool
interleaved
,
const
int
cp_size
,
const
int
cp_rank
)
{
auto
act_options
=
at
::
TensorOptions
().
dtype
(
q_grad_out
.
scalar_type
()).
device
(
q_grad_out
.
device
());
auto
qkv_grad_size
=
q_grad_out
.
sizes
().
vec
();
auto
total_hd
=
(
q_grad_out
.
size
(
2
)
+
k_grad_out
.
size
(
2
)
+
v_grad_out
.
size
(
2
))
*
q_grad_out
.
size
(
3
);
auto
total_d
=
qkv_split_arg_list
[
0
]
+
qkv_split_arg_list
[
1
]
+
qkv_split_arg_list
[
2
];
qkv_grad_size
[
2
]
=
total_hd
/
total_d
;
qkv_grad_size
[
3
]
=
total_d
;
auto
qkv_grad_input
=
at
::
empty
(
qkv_grad_size
,
act_options
);
const
bool
is_sbhd
=
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
;
const
int
s
=
is_sbhd
?
q_grad_out
.
size
(
0
)
:
q_grad_out
.
size
(
1
);
const
int
b
=
is_sbhd
?
q_grad_out
.
size
(
1
)
:
q_grad_out
.
size
(
0
);
const
int
h
=
qkv_grad_input
.
size
(
2
);
const
int
d
=
qkv_split_arg_list
[
2
];
const
int
d2
=
q_freqs
.
size
(
3
);
auto
q_grad_out_cu
=
makeTransformerEngineTensor
(
q_grad_out
);
auto
k_grad_out_cu
=
makeTransformerEngineTensor
(
k_grad_out
);
auto
v_grad_out_cu
=
makeTransformerEngineTensor
(
v_grad_out
);
auto
q_freqs_cu
=
makeTransformerEngineTensor
(
q_freqs
);
auto
k_freqs_cu
=
makeTransformerEngineTensor
(
k_freqs
);
auto
qkv_grad_cu
=
makeTransformerEngineTensor
(
qkv_grad_input
);
nvte_fused_qkv_rope_backward
(
q_grad_out_cu
.
data
(),
k_grad_out_cu
.
data
(),
v_grad_out_cu
.
data
(),
q_freqs_cu
.
data
(),
k_freqs_cu
.
data
(),
qkv_grad_cu
.
data
(),
qkv_format
,
interleaved
,
cp_size
,
cp_rank
,
s
,
b
,
h
,
d
,
d2
,
qkv_split_arg_list
[
0
],
qkv_split_arg_list
[
1
],
qkv_split_arg_list
[
2
],
at
::
cuda
::
getCurrentCUDAStream
());
return
qkv_grad_input
;
}
}
// namespace transformer_engine::pytorch
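For intuition, a small Python sketch (illustrative only, not part of this extension) of how the q/k/v output shapes follow from the packed 4D qkv input and qkv_split_arg_list in the forward function above; the interpretation of the split list as per-head chunk widths of the packed last dimension is an assumption drawn from the size arithmetic:

def fused_qkv_rope_output_shapes(qkv_shape, qkv_split_arg_list):
    """Mirror the shape math of fused_qkv_rope_forward.

    qkv_shape is the packed 4D input shape (e.g. [s, b, h, q_width + k_width + v_width]),
    qkv_split_arg_list is assumed to be [q_width, k_width, v_width] along the last dim.
    """
    q_width, k_width, v_width = qkv_split_arg_list
    q_shape = list(qkv_shape)
    # q's chunk is reinterpreted as (q_width / k_width) heads of dim k_width.
    q_shape[2] = qkv_shape[2] * q_width // k_width
    q_shape[3] = k_width
    k_shape = list(qkv_shape)
    k_shape[3] = k_width
    v_shape = list(qkv_shape)
    v_shape[3] = v_width
    return q_shape, k_shape, v_shape

# Example with placeholder sizes: 32 packed heads of width 256+128+128.
print(fused_qkv_rope_output_shapes([4096, 2, 32, 512], [256, 128, 128]))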
transformer_engine/pytorch/csrc/extensions/cast.cpp
View file @ 27ddce40
...
...
@@ -205,11 +205,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
                            size_t offset, at::ScalarType dtype) -> at::Tensor {
    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-   // in the case where full buffer is empty because local rank receives no tokens for all the experts
-   // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
-   // but in the case where some experts receive tokens, some not, we want to leverage from_blob
-   // as much as possible to avoid CPU overhead
-   if (buffer->data_ptr<uint8_t>() == nullptr) {
+   bool is_empty_shape = product(shape) == 0;
+   if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
    }
    return at::from_blob(
...
...
@@ -359,11 +356,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
                            size_t offset, at::ScalarType dtype) -> at::Tensor {
    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-   // in the case where full buffer is empty because local rank receives no tokens for all the experts
-   // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
-   // but in the case where some experts receive tokens, some not, we want to leverage from_blob
-   // as much as possible to avoid CPU overhead
-   if (buffer->data_ptr<uint8_t>() == nullptr) {
+   bool is_empty_shape = product(shape) == 0;
+   if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
    }
    return at::from_blob(
...
...
transformer_engine/pytorch/csrc/extensions/dropout.cpp
0 → 100644
View file @ 27ddce40
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/dropout.h"

#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <pybind.h>

#include <ATen/cuda/CUDAGraphsUtils.cuh>

#include "../common.h"
#include "../extensions.h"
#include "../pybind.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
namespace pytorch {

std::vector<py::object> dropout_fwd(const py::handle &input, float dropout_probability,
                                    std::optional<at::Tensor> out) {
  using namespace transformer_engine::pytorch::detail;

  // Input tensor
  const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none());

  // Allocate output tensor if needed
  if (!out) {
    at::ScalarType dtype = GetATenDType(input_nvte.dtype());
    if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) {
      dtype = input.attr("dtype").cast<at::ScalarType>();
    }
    const auto shape_uint64 = convertShape(input_nvte.shape());
    const std::vector<int64_t> shape_int64(shape_uint64.begin(), shape_uint64.end());
    const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA);
    out = at::empty(shape_int64, opts);
  }
  TensorWrapper out_nvte = makeTransformerEngineTensor(*out);

  // Mask tensor
  auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte);
  auto mask_nvte = makeTransformerEngineTensor(mask_pyt);

  // RNG state tensor
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState philox_args;
  {
    std::lock_guard<std::mutex> lock(gen->mutex_);
    constexpr int64_t rng_elts_per_thread = 4;
    philox_args = gen->philox_cuda_state(rng_elts_per_thread);
  }
  auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64);
  NVTE_SCOPED_GIL_RELEASE({
    nvte_extract_seed_and_offset(
        reinterpret_cast<int64_t *>(rng_state_pyt.data_ptr()), philox_args.captured_,
        philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr,
        philox_args.offset_.val, philox_args.offset_intragraph_,
        at::cuda::getCurrentCUDAStream());
  });
  auto rng_state_nvte = makeTransformerEngineTensor(rng_state_pyt);

  // Launch kernel
  NVTE_SCOPED_GIL_RELEASE({
    nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(),
                     dropout_probability, at::cuda::getCurrentCUDAStream());
  });

  return {py::cast(std::move(*out)), py::cast(mask_pyt)};
}

py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
                       const float dropout_probability, std::optional<at::Tensor> grad_input) {
  const auto grad_output_nvte = makeTransformerEngineTensor(grad_output);
  const auto mask_nvte = makeTransformerEngineTensor(mask);
  if (!grad_input) {
    grad_input = at::empty_like(grad_output);
  }
  auto grad_input_nvte = makeTransformerEngineTensor(*grad_input);
  NVTE_SCOPED_GIL_RELEASE({
    nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(),
                     dropout_probability, at::cuda::getCurrentCUDAStream());
  });
  return py::cast(std::move(*grad_input));
}

}  // namespace pytorch
}  // namespace transformer_engine
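A hedged Python-level usage sketch of these two entry points (argument names come from the pybind registration later in this commit; the sample sizes are placeholders):

import torch
import transformer_engine_torch as tex

x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)

# Forward returns the dropped-out tensor plus a packed bit mask (one byte per 8 elements).
out, mask = tex.dropout_fwd(x, 0.1)

# Backward reuses the mask to zero and rescale the incoming gradient.
grad_out = torch.ones_like(out)
grad_in = tex.dropout_bwd(grad_out, mask, 0.1)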
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @ 27ddce40
...
...
@@ -95,6 +95,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                             bool use_split_accumulator, CommOverlapCore *comm_overlap,
                             std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
                             bool bulk_overlap, float alpha, std::optional<float> beta) {
  using namespace transformer_engine::pytorch::detail;

  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
...
...
@@ -125,10 +127,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
          "into D tensor. Beta has nothing to be applied to.");
  }

+ DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
  // Output tensor
  TensorWrapper D_tensor;
  if (D.is_none()) {
-   DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
    std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer);
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
...
...
@@ -141,12 +143,35 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
    }
  }

  // maintain unquantized tensor in case we need unfused quantization support.
  TensorWrapper unquantized_D_tensor;
  py::object unquantized_out;
  // Unfused quantization is needed in the following cases
  // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that)
  // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling,
  //    GEMM Output needs to be in BF16, to allow for unfused quantization)
  bool unfused_quantization_needed = !quantizer.is_none();
  if (low_precision) {
    // At the moment, only use-case for fused GEMM:
    // Delayed scaling quantizer with per-tensor scaling inputs
    bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr());
    if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input)
      unfused_quantization_needed = false;
  }
  if (unfused_quantization_needed) {
    NoneQuantizer q{none};
    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype);
  }
  TensorWrapper &out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor;

  // Bias tensor
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
-     auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
+     auto opts = torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
...
...
@@ -159,7 +184,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
  // Activation input tensor
  MaybeTensor pre_gelu_out = std::nullopt;
- DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
+ DType gelu_type = low_precision ? bias_type : out_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
...
...
@@ -212,7 +237,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
    // Direct GEMM call to the correct overlap
    if (bulk_overlap) {
      NVTE_SCOPED_GIL_RELEASE({
-       comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+       comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
                                   te_pre_gelu_out, te_workspace, grad, accumulate,
                                   use_split_accumulator, comm_type.value(), extra_output_tensor,
                                   main_stream);
...
...
@@ -220,14 +245,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
      }
    } else if (comm_type.value() == CommOverlapType::AG) {
      if (comm_overlap->is_atomic_gemm()) {
        NVTE_SCOPED_GIL_RELEASE({
-         comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+         comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                               accumulate, use_split_accumulator,
                                               extra_output_tensor, main_stream);
        });
      } else {
        NVTE_SCOPED_GIL_RELEASE({
-         comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+         comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
                                         te_pre_gelu_out, te_workspace, grad, accumulate,
                                         use_split_accumulator, extra_output_tensor, main_stream);
...
...
@@ -236,14 +261,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
      }
    } else {
      if (comm_overlap->is_atomic_gemm()) {
        NVTE_SCOPED_GIL_RELEASE({
-         comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+         comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                               accumulate, use_split_accumulator,
                                               extra_output_tensor, main_stream);
        });
      } else {
        NVTE_SCOPED_GIL_RELEASE({
-         comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+         comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
                                         te_pre_gelu_out, te_workspace, grad, accumulate,
                                         use_split_accumulator, extra_output_tensor, main_stream);
...
...
@@ -253,15 +278,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
      }
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
-       nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+       nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
                                bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                te_workspace.data(), alpha, *beta, use_split_accumulator,
                                num_math_sms, main_stream);
      });
    }
  } else {
-   if (D_tensor.numel() != 0 && !accumulate) {
-     D_tensor.zero_(main_stream);
+   if (out_tensor.numel() != 0 && !accumulate) {
+     out_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
...
...
@@ -269,7 +294,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
      }
    }
  }

  if (unfused_quantization_needed) {
    // Quantize the output
    std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
    my_quantizer->quantize(unquantized_D_tensor, D_tensor);
  }
  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
...
...
@@ -449,24 +478,12 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
  }

  // For now, we only have multi-stream cublas backend.
- const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
- if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
-   NVTE_SCOPED_GIL_RELEASE({
-     nvte_grouped_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                       te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(),
-                       transa, transb, grad, te_workspace_vector.data(), accumulate,
-                       use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream());
-   });
- } else {
-   NVTE_SCOPED_GIL_RELEASE({
-     nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                                   te_bias_vector.data(), te_pre_gelu_out_vector.data(),
-                                   te_A_vector.size(), transa, transb, grad,
-                                   te_workspace_vector.data(), accumulate, use_split_accumulator,
-                                   math_sm_count, at::cuda::getCurrentCUDAStream());
-   });
- }
+ NVTE_SCOPED_GIL_RELEASE({
+   nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
+                          te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(),
+                          transa, transb, grad, te_workspace_vector.data(), accumulate,
+                          use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream());
+ });
  return bias;
}
...
...
transformer_engine/pytorch/csrc/extensions/normalization.cpp
View file @ 27ddce40
...
...
@@ -110,7 +110,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  TensorWrapper unquantized_out_cu;
  py::object unquantized_out;
  if (force_unfused_kernel) {
-   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+       !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      std::tie(unquantized_out_cu, unquantized_out) =
          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
...
...
@@ -145,7 +146,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
-   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+       !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
    } else {
...
...
@@ -199,6 +201,52 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
  return {py::cast(dx), py::cast(dgamma)};
}

std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
                                        const at::Tensor &add, const at::Tensor &rsigma,
                                        const at::Tensor &gamma, const int sm_margin,
                                        const bool zero_centered_gamma) {
  const auto &dz_ = dz.contiguous();
  const auto &x_ = x.contiguous();
  const auto &add_ = add.contiguous();
  const auto &rsigma_ = rsigma.contiguous();
  const auto &gamma_ = gamma.contiguous();

  auto dx = at::empty_like(x_);
  auto dgamma = at::empty_like(gamma_);
  TensorWrapper workspace;

  auto dz_cu = makeTransformerEngineTensor(dz_);
  auto x_cu = makeTransformerEngineTensor(x_);
  auto add_cu = makeTransformerEngineTensor(add_);
  auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
  auto gamma_cu = makeTransformerEngineTensor(gamma_);
  auto dx_cu = makeTransformerEngineTensor(dx);
  auto dgamma_cu = makeTransformerEngineTensor(dgamma);

  // This call populates tensors with the required config.
  NVTE_SCOPED_GIL_RELEASE({
    nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(),
                         gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(),
                         at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                         zero_centered_gamma, at::cuda::getCurrentCUDAStream());
  });

  // Alloc space for Tensors.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Actual call to bwd kernel.
  NVTE_SCOPED_GIL_RELEASE({
    nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(),
                         gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(),
                         at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                         zero_centered_gamma, at::cuda::getCurrentCUDAStream());
  });

  return {py::cast(dx), py::cast(dgamma)};
}

std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
                                    py::object out, py::handle quantizer, DType out_dtype,
                                    const int sm_margin, const bool zero_centered_gamma) {
...
...
@@ -244,7 +292,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  TensorWrapper unquantized_out_cu;
  py::object unquantized_out;
  if (force_unfused_kernel) {
-   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+       !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      std::tie(unquantized_out_cu, unquantized_out) =
          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
...
...
@@ -279,7 +328,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
-   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+   if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+       !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
    } else {
...
...
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @ 27ddce40
...
...
@@ -113,38 +113,53 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
        py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);

  /* GELU and variants*/
  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
        py::arg("quantizer"));
  /* ReLU and variants */
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
        py::arg("quantizer"));
- m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
+ m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
        py::arg("quantizer"));
- m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+ m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"),
        py::arg("quantizer"));
- m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
+ /* SwiGLU and variants */
+ m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
        py::arg("quantizer"));
  /* Backward of GELU and variants */
  m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  /* Backward of ReLU and variants */
  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
- m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
+ m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
- m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+ m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
+ /* Backward of SiLU and variants */
  m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
- m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
+ m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
  /* DBias + DAct fusions*/
  m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
  m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
...
...
@@ -202,6 +217,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
  m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
+ m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add,
+       "Fused backward of RMSNorm + add");
  m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
        "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
  m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
...
...
@@ -281,6 +298,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
+ m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
+       "Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
+ m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
+       "Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
  // fused router
  m.def("fused_topk_with_score_function_fwd",
...
...
@@ -308,6 +329,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
        py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
+ // Dropout
+ m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG",
+       py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt);
+ m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG",
+       py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"),
+       py::arg("grad_input") = std::nullopt);
  // Misc
  m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
        "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
...
...
transformer_engine/pytorch/csrc/quantizer.cpp
View file @ 27ddce40
...
...
@@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
  at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
  tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                   getTensorShape(amax));
- auto rowwise_data = tensor->get_rowwise_data();
- rowwise_data.dtype = static_cast<NVTEDType>(dtype);
- auto columnwise_data = tensor->get_columnwise_data();
- columnwise_data.dtype = static_cast<NVTEDType>(dtype);
- tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                          rowwise_data.shape);
- tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                             columnwise_data.shape);
}

std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
...
...
@@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
  at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
  tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                   getTensorShape(amax));
- // quantize output and its transpose
- auto rowwise_data = tensor->get_rowwise_data();
- rowwise_data.dtype = static_cast<NVTEDType>(dtype);
- auto columnwise_data = tensor->get_columnwise_data();
- columnwise_data.dtype = static_cast<NVTEDType>(dtype);
- tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                          rowwise_data.shape);
- tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                             columnwise_data.shape);
}

std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
...
...
@@ -518,7 +497,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te
  // Compute amax
  if (compute_amax) {
-   NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
+   NVTE_SCOPED_GIL_RELEASE(
+       { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
  }

  // Perform amax reduction if needed
...
...
@@ -561,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
  this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
}

- void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
-   // Change the rowwise and columnwise_data to the configured dtype.
-   // May be a switch between E5M2 and E4M3.
-   auto rowwise_data = tensor->get_rowwise_data();
-   rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-   auto columnwise_data = tensor->get_columnwise_data();
-   columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-   tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                            rowwise_data.shape);
-   tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                               columnwise_data.shape);
- }
+ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}

std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype) const {
...
...
@@ -916,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize
  this->dtype = quantizer.attr("dtype").cast<DType>();
}

- void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
-   auto rowwise_data = tensor->get_rowwise_data();
-   rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-   auto columnwise_data = tensor->get_columnwise_data();
-   columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-   tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                            rowwise_data.shape);
-   tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                               columnwise_data.shape);
- }
+ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {}

std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
                                                                   DType dtype) const {
...
...
transformer_engine/pytorch/graph.py
View file @ 27ddce40
...
...
@@ -4,6 +4,8 @@
"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
+import contextlib
+import gc
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import torch
...
...
@@ -58,6 +60,25 @@ def graph_pool_handle():
    return _graph_pool_handle()


+@contextlib.contextmanager
+def _graph_context_wrapper(*args, **kwargs):
+    """Wrapper around `torch.cuda.graph`.
+
+    This wrapper is a temporary workaround for a PyTorch bug:
+    automatic garbage collection can destroy a graph while another
+    graph is being captured, resulting in a CUDA error. See
+    https://github.com/pytorch/pytorch/pull/161037.
+
+    """
+    gc_is_enabled = gc.isenabled()
+    if gc_is_enabled:
+        gc.disable()
+    with torch.cuda.graph(*args, **kwargs):
+        yield
+    if gc_is_enabled:
+        gc.enable()
+
+
def _make_graphed_callables(
    callables: SingleOrTuple[Callable],
    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
...
...
@@ -445,7 +466,7 @@ def _make_graphed_callables(
                args = sample_args[per_callable_fwd_idx]
                kwargs = sample_kwargs[per_callable_fwd_idx]
                fwd_graph = fwd_graphs[per_callable_fwd_idx]
-               with torch.cuda.graph(fwd_graph, pool=mempool):
+               with _graph_context_wrapper(fwd_graph, pool=mempool):
                    outputs = func(*args, **kwargs)
                    flatten_outputs, spec = _tree_flatten(outputs)
                per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
...
...
@@ -483,7 +504,7 @@ def _make_graphed_callables(
                    torch.empty_like(o) if o.requires_grad else None for o in static_outputs
                )
                if is_training:
-                   with torch.cuda.graph(bwd_graph, pool=mempool):
+                   with _graph_context_wrapper(bwd_graph, pool=mempool):
                        grad_inputs = torch.autograd.grad(
                            outputs=tuple(o for o in static_outputs if o.requires_grad),
                            inputs=tuple(i for i in static_input_surface if i.requires_grad),
...
...
@@ -548,7 +569,7 @@ def _make_graphed_callables(
        per_callable_output_unflatten_spec = []
        graph_id = 0
        for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
-           with torch.cuda.graph(fwd_graph, pool=mempool):
+           with _graph_context_wrapper(fwd_graph, pool=mempool):
                outputs = func(*args, **kwargs)
            graph_callables[graph_id] = func
            graph_id += 1
...
...
@@ -570,7 +591,7 @@ def _make_graphed_callables(
                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
            )
            if is_training:
-               with torch.cuda.graph(bwd_graph, pool=mempool):
+               with _graph_context_wrapper(bwd_graph, pool=mempool):
                    grad_inputs = torch.autograd.grad(
                        outputs=tuple(o for o in static_outputs if o.requires_grad),
                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
...
...
@@ -829,7 +850,7 @@ def make_graphed_callables(
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
-   fp8_enabled: bool = False,
+   fp8_enabled: SingleOrTuple[bool] = False,
    fp8_calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
...
...
@@ -875,8 +896,9 @@ def make_graphed_callables(
    FP8-related parameters
    ----------------------
-   fp8_enabled: bool, default = `True`
-                whether or not to enable fp8
+   fp8_enabled: (tuple of) bool, default = `False`
+                whether or not to enable fp8.
+                If tuple, the length must match the number of modules.
    fp8_calibrating: bool, default = `False`
                     calibration mode allows collecting statistics such as amax and scale
                     data of fp8 tensors even when executing without fp8 enabled. This is
...
...
@@ -898,17 +920,25 @@ def make_graphed_callables(
    """
    set_capture_start()

-   if fp8_enabled and fp8_recipe is None:
-       fp8_recipe = get_default_fp8_recipe()
-   elif not fp8_enabled:
-       fp8_recipe = None
-
    # Handle single module.
    just_one_callable = False
    if not isinstance(modules, tuple):
        just_one_callable = True
        modules = (modules,)

+   if not isinstance(fp8_enabled, tuple):
+       assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools"
+       fp8_enabled = (fp8_enabled,) * len(modules)
+   else:
+       assert len(fp8_enabled) == len(
+           modules
+       ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})"
+   if any(fp8_enabled) and fp8_recipe is None:
+       fp8_recipe = get_default_fp8_recipe()
+   elif not any(fp8_enabled):
+       fp8_recipe = None
+   module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))
+
    # Store FP8 tensors to reset later.
    saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
...
...
@@ -923,15 +953,15 @@ def make_graphed_callables(
            old_call_funcs[block_cls] = block_cls.__call__

            # Wrap the original call function of the module class.
-           def call_func(*args, **kwargs):
+           def call_func(self, *args, **kwargs):
                with fp8_autocast(
-                   enabled=fp8_enabled,
+                   enabled=module_uses_fp8.get(id(self), False),
                    calibrating=fp8_calibrating,
                    fp8_recipe=fp8_recipe,
                    fp8_group=fp8_group,
                    _graph=True,
                ):
-                   outputs = old_call_funcs[block_cls](*args, **kwargs)
+                   outputs = old_call_funcs[block_cls](self, *args, **kwargs)
                return outputs

            block_cls.__call__ = call_func
...
...
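A minimal usage sketch of the per-module capture switch described above (module sizes and inputs are placeholders, not taken from the commit):

import torch
import transformer_engine.pytorch as te

m1 = te.Linear(1024, 1024).cuda()
m2 = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

# fp8_enabled may now be a tuple with one entry per module:
# m1 is captured under FP8 autocast, m2 in higher precision.
graphed_m1, graphed_m2 = te.make_graphed_callables(
    (m1, m2),
    ((x,), (x,)),
    fp8_enabled=(True, False),
)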
transformer_engine/pytorch/module/__init__.py
View file @ 27ddce40
...
...
@@ -12,4 +12,4 @@ from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
from .fp8_padding import Fp8Padding
from .fp8_unpadding import Fp8Unpadding
-from .base import initialize_ub, destroy_ub
+from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
transformer_engine/pytorch/module/base.py
View file @ 27ddce40
...
...
@@ -8,6 +8,7 @@ import math
import os
import pickle
import warnings
+from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
...
...
@@ -50,7 +51,7 @@ from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTe
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
from torch.utils.cpp_extension import IS_HIP_EXTENSION

-__all__ = ["initialize_ub", "destroy_ub"]
+__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]

_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
...
...
@@ -66,6 +67,15 @@ _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []


+class UserBufferQuantizationMode(Enum):
+    """
+    UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer.
+    """
+
+    NONE = "none"
+    FP8 = "fp8"
+
+
def get_cublas_workspace_size_bytes() -> None:
    """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
    # Add env for control the padding for blaslt
...
...
@@ -134,8 +144,9 @@ def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
+   quantization_modes: List[UserBufferQuantizationMode] = None,
    dtype: torch.dtype = torch.bfloat16,
-   ub_cfgs: Optional[dict] = None,
+   ub_cfgs: Optional[Union[dict, List[dict]]] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
    r"""
...
...
@@ -151,7 +162,11 @@ def initialize_ub(
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
-             allocate the communication buffer for FP8 GEMM inputs/outputs
+             allocate the communication buffer for FP8 GEMM inputs/outputs.
+             DEPRECATED: Please use `quantization_modes` instead.
+   quantization_modes : List[UserBufferQuantizationMode] = None
+             if a list of UserBufferQuantizationMode is provided, a UB communicator is created
+             for each quantization setting in the list.
+             falls back to the legacy `use_fp8` parameter if `None` is provided.
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
...
...
@@ -175,6 +190,7 @@ def initialize_ub(
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
             "fc2_fprop", "fc2_wgrad"]`.
+            a list may be provided to specify different overlap configurations for the different
+            quantization settings in `quantization_modes`
    bootstrap_backend : str = None
             `torch.distributed` communication backend for the all-gather, broadcast and
             barrier collectives during Userbuffers initialization. Not all backends are
...
...
@@ -191,6 +207,28 @@ def initialize_ub(
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

+   if not quantization_modes:
+       warnings.warn(
+           "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
+           " instead.",
+           DeprecationWarning,
+       )
+       quantization_modes = [
+           UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
+       ]
+   else:
+       assert isinstance(quantization_modes, list), "quantization_modes must be a list"
+       assert all(
+           isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
+       ), "quantization_modes must be a list of UserBufferQuantizationMode"
+
+   if isinstance(ub_cfgs, dict) or ub_cfgs is None:
+       ub_cfgs = [ub_cfgs] * len(quantization_modes)
+   else:
+       assert len(ub_cfgs) == len(
+           quantization_modes
+       ), "Number of ub_cfgs settings must match number of quantization configurations"
+
    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}
...
...
@@ -349,6 +387,7 @@ def initialize_ub(
    def add_ub(
        name: str,
+       quantization_mode: UserBufferQuantizationMode,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
...
...
@@ -367,7 +406,9 @@ def initialize_ub(
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
-           assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
+           assert (
+               quantization_mode == UserBufferQuantizationMode.FP8
+           ), "Atomic GEMM overlap supported only for FP8 GEMM."
        if method in ("bulk", "external"):
            warnings.warn(
                f"At {name}, atomic GEMM is not supported for a bulk overlap."
...
...
@@ -407,7 +448,11 @@ def initialize_ub(
                f"{external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
            )

-       buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
+       buffer_dtype = (
+           torch.uint8
+           if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
+           else dtype
+       )
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
...
...
@@ -441,42 +486,52 @@ def initialize_ub(
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
-       _ub_communicators[name] = ub_obj
-
-   if ub_cfgs is not None:
-       for name in dgrad_reduce_scatter_overlap:
-           if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
-               wgrad_name = name.replace("dgrad", "wgrad")
-               assert wgrad_name not in ub_cfgs
-               layers_reduce_scatter_overlap.remove(wgrad_name)
-               layers_all_gather_overlap.remove(name)
-               layers_reduce_scatter_overlap.append(name)
-               methods["bulk"].remove(name)
-               new_method = ub_cfgs[name]["method"]
-               methods[new_method].append(name)
-
-   for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]:
-       if name in remove_ag_gemm_dgrad:
-           continue
-       ub_cfg = get_default_config(name)
-       if ub_cfgs is not None and name in ub_cfgs:
-           fp8_buf = (name in layers_all_gather_overlap) or (
-               ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
-           )
-           ub_cfg.update(ub_cfgs[name])
-           ub_cfg["fp8_buf"] = fp8_buf
-       add_ub(name, **ub_cfg)
+       _ub_communicators[(name, quantization_mode)] = ub_obj
+
+   for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
+       if user_ub_cfg is not None:
+           for name in dgrad_reduce_scatter_overlap:
+               if (
+                   name in user_ub_cfg
+                   and "method" in user_ub_cfg[name]
+                   and user_ub_cfg[name]["method"] != "bulk"
+               ):
+                   wgrad_name = name.replace("dgrad", "wgrad")
+                   assert wgrad_name not in user_ub_cfg
+                   layers_reduce_scatter_overlap.remove(wgrad_name)
+                   layers_all_gather_overlap.remove(name)
+                   layers_reduce_scatter_overlap.append(name)
+                   methods["bulk"].remove(name)
+                   new_method = user_ub_cfg[name]["method"]
+                   methods[new_method].append(name)
+
+       for name in (
+           methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
+       ):
+           if name in remove_ag_gemm_dgrad:
+               continue
+           ub_cfg = get_default_config(name)
+           if user_ub_cfg is not None and name in user_ub_cfg:
+               fp8_buf = (name in layers_all_gather_overlap) or (
+                   user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
+               )
+               ub_cfg.update(user_ub_cfg[name])
+               ub_cfg["fp8_buf"] = fp8_buf
+           add_ub(name, quantization_mode, **ub_cfg)


-def get_ub(name: str):
+def get_ub(name: str, use_fp8: bool):
    """Get userbuffer communicator corresponding to given key."""
+   # For now use `use_fp8` boolean input as it matches the current design in the modules
+   # So favour simplicity until the correct design becomes clear.
+   # This is mainly an internal API so we don't need to worry about future changes
+   key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE)
    assert _ub_communicators is not None, "UB manager is not initialized."
+   assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
    # assert name in _ub_communicators, f"UB for {name} is not registered."
    if name in remove_ag_gemm_dgrad:
        return None
-   return _ub_communicators[name]
+   return _ub_communicators[key]


def destroy_ub():
...
...
@@ -1472,8 +1527,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        (wgrad, bgrad), _ = self.wgrad_store.pop()
        if not self.fuse_wgrad_accumulation:
            weight_tensor = noop_cat(self._get_weight_tensors())
-           if weight_tensor.grad is None:
-               weight_tensor.grad = wgrad.to(weight_tensor.dtype)
+           weight_tensor.grad = wgrad.to(weight_tensor.dtype)
        if self.use_bias:
            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
            if bias_tensor.grad is None:
...
...
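A hedged configuration sketch of the new multi-mode initialization (buffer shape and tp_size are placeholder values; the import path follows module/base.py and module/__init__.py as changed above):

import torch
from transformer_engine.pytorch.module.base import (
    UserBufferQuantizationMode,
    initialize_ub,
)

# One set of Userbuffers communicators is created per entry in quantization_modes;
# modules later pick the matching one through the internal get_ub(name, use_fp8) lookup.
initialize_ub(
    shape=[4096, 8192],   # flattened [seq * batch, hidden] buffer shape (placeholder values)
    tp_size=8,
    quantization_modes=[
        UserBufferQuantizationMode.FP8,
        UserBufferQuantizationMode.NONE,
    ],
    dtype=torch.bfloat16,
)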
transformer_engine/pytorch/module/grouped_linear.py
View file @ 27ddce40
...
...
@@ -859,8 +859,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
        if not self.fuse_wgrad_accumulation:
            for i in range(self.num_gemms):
-               if weight_params[i].grad is None:
-                   weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
+               weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
        if self.use_bias:
            for i in range(self.num_gemms):
                if bias_params[i].grad is None:
...
...
@@ -917,7 +916,7 @@ class GroupedLinear(TransformerEngineBaseModule):
    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Get the weight quantizers of the module."""
-       if not self.fp8:
+       if not self.fp8 and not self.fp8_calibration:
            return [None] * self.num_gemms
        weight_quantizers = [
            self.quantizers["scaling_fwd"][
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @ 27ddce40
...
...
@@ -181,10 +181,10 @@ class _LayerNormLinear(torch.autograd.Function):
            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
        )
        if ub_overlap_rs_fprop:
-           ub_obj = get_ub(ub_name + "_fprop")
+           ub_obj = get_ub(ub_name + "_fprop", fp8)
            ub_type = tex.CommOverlapType.RS
        elif ub_overlap_ag_fprop:
-           ub_obj = get_ub(ub_name + "_fprop")
+           ub_obj = get_ub(ub_name + "_fprop", fp8)
            ub_type = tex.CommOverlapType.AG

        # Configure quantizer for norm output
...
...
@@ -361,8 +361,11 @@ class _LayerNormLinear(torch.autograd.Function):
        # Deallocate GEMM input tensor if no longer needed
        if not weight.requires_grad and not return_layernorm_output:
-           ln_out = ln_out_total = None
+           clear_tensor_data(ln_out, ln_out_total)
+           ln_out = ln_out_total = None
+       elif with_input_all_gather and not return_layernorm_output_gathered:
+           clear_tensor_data(ln_out_total)
+           ln_out_total = None

        # ------------------------------------------------------
        # Prepare output tensor
...
...
@@ -608,23 +611,23 @@ class _LayerNormLinear(torch.autograd.Function):
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
            if ctx.ub_overlap_ag:
                # Overlap grad_output all-gather with dgrad compute
-               ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+               ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
-               ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+               ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
-                   ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                   ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
-                   ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                   ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                    ub_type_wgrad = tex.CommOverlapType.RS

            # --------------------------------------------------
...
...
@@ -802,7 +805,7 @@ class _LayerNormLinear(torch.autograd.Function):
                dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

                # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-               ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+               ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)

                ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
...
@@ -927,9 +930,19 @@ class _LayerNormLinear(torch.autograd.Function):
                grad_bias = grad_bias_
                del grad_bias_

-           # Deallocate input tensor if permitted
-           if not ctx.return_layernorm_output:
+           # Deallocate input tensors if permitted
+           if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
+               # Input tensors have not been exposed externally
                clear_tensor_data(ln_out)
+           elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered:
+               # Non-gathered input has not been exposed externally
+               clear_tensor_data(ln_out)
+           if ctx.ln_out_needs_gather:
+               # Gathered input is internal
+               clear_tensor_data(ln_out_total)
+           if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+               # Gathered grad output tensor is internal
+               clear_tensor_data(grad_output)

            # Update grad input if overlapping reduce-scatter with wgrad GEMM
            if ctx.ub_bulk_wgrad:
...
...
@@ -1209,7 +1222,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
        self.return_bias = return_bias
        self.apply_bias = self.use_bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
-       self.return_layernorm_output_gathered = return_layernorm_output_gathered
+       self.return_layernorm_output_gathered = (
+           return_layernorm_output_gathered if return_layernorm_output else False
+       )
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type
...
...
@@ -1532,10 +1547,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
            is_first_microbatch = False

        if self.ub_overlap_rs_fprop:
-           if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+           if get_ub(
+               self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+           ).is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
-           if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+           if get_ub(
+               self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+           ).is_fp8_ubuf():
                fp8_grad = True

        with torch.cuda.device(
...
...
@@ -1803,7 +1822,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Get the weight quantizers of the module."""
-       if not self.fp8:
+       if not self.fp8 and not self.fp8_calibration:
            return [None]
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
...
...