OpenDAS / TransformerEngine · Commits

Commit 4099aa8e, authored Mar 20, 2025 by yuguo

    Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: c520cba3, 96f9c6de

Changes: 49 · Showing 9 changed files with 1372 additions and 291 deletions (+1372 −291)
Files changed:

    transformer_engine/pytorch/csrc/extensions/pybind.cpp           +14   −8
    transformer_engine/pytorch/csrc/kv_cache.cuh                     +145  −0
    transformer_engine/pytorch/distributed.py                        +15   −2
    transformer_engine/pytorch/dot_product_attention/inference.py    +777  −32
    transformer_engine/pytorch/dot_product_attention/utils.py        +336  −190
    transformer_engine/pytorch/graph.py                              +44   −28
    transformer_engine/pytorch/module/layernorm_linear.py            +18   −11
    transformer_engine/pytorch/module/layernorm_mlp.py               +7    −17
    transformer_engine/pytorch/module/linear.py                      +16   −3
transformer_engine/pytorch/csrc/extensions/pybind.cpp (+14 −8)

```diff
@@ -177,16 +177,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef USE_ROCM
   m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
 #endif
-  m.def("fused_attn_fwd", &fused_attn_fwd,
-        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
-        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
         py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
@@ -194,6 +186,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
         py::call_guard<py::gil_scoped_release>());
+  // attention kernels
+  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_attn_fwd", &fused_attn_fwd,
+        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
+  m.def("fused_attn_bwd", &fused_attn_bwd,
+        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
+  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
+  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
+  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tensor from BSHD to THD");
   // fused apply rope
   m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
         py::call_guard<py::gil_scoped_release>());
```
transformer_engine/pytorch/csrc/kv_cache.cuh (new file, mode 100644, +145 −0)

```cpp
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_

namespace transformer_engine {
namespace fused_attn {

template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: thd; new_tensor: bshd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    scalar_t *thd_token = tensor + thd_offset;
    scalar_t *bshd_token = new_tensor + bshd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(bshd_token + i) = *(thd_token + i);
    }
  }
}

template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: bshd; new_tensor: thd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
    int num_elts = seqlen * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    scalar_t *bshd_token = tensor + bshd_offset;
    scalar_t *thd_token = new_tensor + thd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(thd_token + i) = *(bshd_token + i);
    }
  }
}

template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                        int d_v, int b, int max_seq_len) {
  // k_cache, v_cache: bshd
  // batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
  int actual_b = b;
  for (int i = 0; i < b - 1; i++) {
    if (batch_indices[i + 1] < batch_indices[i]) {
      actual_b = i + 1;
    }
  }
  for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
    int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
    int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
    for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
      int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
      int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
      int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k;
      int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v;
      int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v;
      for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) {
        *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i);
      }
      for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
        *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
      }
    }
  }
}

template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
                                        scalar_t *v_cache, int *page_table, int *cu_new_lens,
                                        int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
                                        int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
                                        int max_pages_per_seq, bool is_non_paged) {
  // new_k, new_v: qkv_format; k_cache, v_cache: bshd
  // cu_new_lens, cu_cached_lens: [b + 1]
  // page_table: [b, max_pages_per_seq]
  int page_size = max_seq_len / max_pages_per_seq;
  if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int new_token_offset = batch_idx * max_ctx_len;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
        for (int j = 0; j < h_kv * d_k; j++) {
          *(k_cache + token_idx * h_kv * d_k + j) =
              *(new_k + (new_token_offset + i) * h_kv * d_k + j);
        }
        for (int j = 0; j < h_kv * d_v; j++) {
          *(v_cache + token_idx * h_kv * d_v + j) =
              *(new_v + (new_token_offset + i) * h_kv * d_v + j);
        }
      }
    }
  } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
        for (int j = 0; j < h_kv * d_k; j++) {
          *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
        }
        for (int j = 0; j < h_kv * d_v; j++) {
          *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
        }
      }
    }
  } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
        for (int j = 0; j < h_kv * d_k; j++) {
          *(k_cache + token_idx * h_kv * d_k + j) =
              *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
        }
        for (int j = 0; j < h_kv * d_v; j++) {
          *(v_cache + token_idx * h_kv * d_v + j) =
              *(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
        }
      }
    }
  }
}

}  // namespace fused_attn
}  // namespace transformer_engine

#endif
```
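As a plain-PyTorch illustration of what `convert_thd_to_bshd_kernel` computes (not the CUDA code above, just a minimal sketch with made-up shapes): packed THD tokens are scattered into a padded BSHD buffer using the `cu_seqlens` prefix-sum offsets.

```python
import torch

# Toy reference for the THD -> BSHD layout conversion; shapes are illustrative.
b, max_seq_len, h, d = 2, 4, 2, 3
cu_seqlens = torch.tensor([0, 3, 5])              # [b + 1]; sequence lengths 3 and 2
thd = torch.randn(int(cu_seqlens[-1]), h, d)      # packed [total_tokens, h, d]
bshd = torch.zeros(b, max_seq_len, h, d)          # padded [b, s, h, d]

for batch_idx in range(b):
    start = int(cu_seqlens[batch_idx])
    end = int(cu_seqlens[batch_idx + 1])
    bshd[batch_idx, : end - start] = thd[start:end]   # copy this sequence's tokens
```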
transformer_engine/pytorch/distributed.py (+15 −2)

```diff
@@ -20,7 +20,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager
+from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.quantized_tensor import QuantizedTensor, Quantizer
@@ -328,11 +328,14 @@ class _CheckpointFunction(torch.autograd.Function):
         tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
         ctx.save_for_backward(*tensor_inputs)

+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
         ctx.get_rng_state_tracker = get_rng_state_tracker
         ctx.tp_group = tp_group
         ctx.recompute_ctx = recompute_ctx
         ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
         ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
+        ctx.fp8 = fp8
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.kwargs = kwargs

         return outputs
@@ -375,6 +378,8 @@ class _CheckpointFunction(torch.autograd.Function):
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
             activation_recompute=True, recompute_phase=True
-        ):
+        ), fp8_autocast(enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
@@ -398,6 +403,9 @@ class _CheckpointFunction(torch.autograd.Function):
                 "none of output has requires_grad=True, this checkpoint() is not necessary"
             )

+        # backward does not require entering autocast context because
+        # backward implementations already retrieve fp8 recipe and
+        # enablement from stored ctx.
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(
             inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
@@ -694,10 +702,15 @@ def checkpoint(
     # Preserve the torch autocast contexts from the forward pass during recompute phase.
     torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

+    fp8 = FP8GlobalStateManager.is_fp8_enabled()
+    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
+
     def recompute_fn(*args, **kwargs):
         with torch.autograd.enable_grad(), (
             te_recompute_ctx
-        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
+            enabled=fp8, fp8_recipe=fp8_recipe
+        ):
             function(*args, **kwargs)

     # Initialize a new checkpoint frame for each new forward pass.
```
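The net effect of this change is that activation checkpointing records whether FP8 was enabled (and which recipe was active) during the original forward, and re-enters `fp8_autocast` with the same settings when the forward is recomputed during backward. A minimal usage sketch, assuming FP8-capable hardware; the module choice and shapes are illustrative, not taken from this commit:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

# The checkpointed forward runs under fp8_autocast; with this change the recompute
# performed inside backward is automatically wrapped in an equivalent fp8_autocast.
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = te.checkpoint(layer, inp)
out.sum().backward()
```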
transformer_engine/pytorch/dot_product_attention/inference.py (+777 −32)

    This diff is collapsed in the web view and is not reproduced here.

transformer_engine/pytorch/dot_product_attention/utils.py (+336 −190)

    This diff is collapsed in the web view and is not reproduced here.
transformer_engine/pytorch/graph.py (+44 −28)

```diff
@@ -91,6 +91,14 @@ def _make_graphed_callables(
         sample_args = (sample_args,)
         sample_kwargs = (sample_kwargs,)

+    # Check training/inference
+    is_training = all(c.training for c in callables)
+    if not is_training and any(c.training for c in callables):
+        assert False, (
+            "make_graphed_callables only supports when modules are all in training or all in"
+            " inference mode."
+        )
+
     # Check sizes of args
     if _order is None:
         assert len(sample_args) == len(callables)
@@ -255,13 +263,16 @@ def _make_graphed_callables(
                 outputs, _ = _tree_flatten(func(*args, **kwargs))
             for hook in hooks:
                 hook.remove()
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-            )
+            if is_training:
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                )
+            else:
+                grad_inputs = None
             del outputs, grad_inputs
         # The following code is added specifically for MCore's special requirements,
         # aimed at preventing warmup from altering the control flow.
@@ -314,22 +325,23 @@ def _make_graphed_callables(
         static_grad_outputs = tuple(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         with torch.cuda.graph(bwd_graph, pool=mempool):
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in static_outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-                retain_graph=retain_graph_in_backward,
-            )
+            if is_training:
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in static_outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                    retain_graph=retain_graph_in_backward,
+                )

         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
         # that don't require grad. I couldn't think of a one-liner for this pattern.
         static_grad_inputs = []
         grad_idx = 0
         for arg in static_input_surface:
-            if arg.requires_grad:
+            if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                 static_grad_inputs.append(grad_inputs[grad_idx])
                 grad_idx += 1
             else:
@@ -366,22 +378,23 @@ def _make_graphed_callables(
         static_grad_outputs = tuple(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         with torch.cuda.graph(bwd_graph, pool=mempool):
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in static_outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-                retain_graph=retain_graph_in_backward,
-            )
+            if is_training:
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in static_outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                    retain_graph=retain_graph_in_backward,
+                )

         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs that
         # don't require grad. I couldn't think of a slick one-liner for this pattern.
         static_grad_inputs = []
         grad_idx = 0
         for arg in static_input_surface:
-            if arg.requires_grad:
+            if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                 static_grad_inputs.append(grad_inputs[grad_idx])
                 grad_idx += 1
             else:
@@ -422,7 +435,10 @@ def _make_graphed_callables(
             # Copy values from new tensors into static tensors
             for i in range(len_user_args):
-                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                if (
+                    isinstance(static_input_surface[i], torch.Tensor)
+                    and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
+                ):
                     static_input_surface[i].copy_(inputs[i])

             # Replay forward graph
```
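These graph.py changes let `make_graphed_callables` capture callables that are all in inference mode: backward-graph gradient computation is skipped when `is_training` is False, and mixing training and eval modules trips the new assertion. A hedged usage sketch (module and shapes are illustrative, not from this commit):

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda().eval()     # all callables in inference mode
sample_input = torch.randn(8, 1024, device="cuda")

# With this change, capture no longer calls torch.autograd.grad during warmup
# or backward-graph capture for eval-mode modules.
graphed = te.make_graphed_callables(layer, (sample_input,))
with torch.no_grad():
    out = graphed(torch.randn(8, 1024, device="cuda"))
```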
transformer_engine/pytorch/module/layernorm_linear.py (+18 −11)

```diff
@@ -79,7 +79,6 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_bias: Union[torch.Tensor, None],
         weight: torch.Tensor,
         bias: torch.Tensor,
-        use_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -383,6 +382,17 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             weightmat,
@@ -411,7 +421,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
             ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
-            ctx.use_bias = use_bias
+            ctx.use_bias = bias is not None
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp_shape
@@ -526,8 +536,11 @@ class _LayerNormLinear(torch.autograd.Function):
             # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
             # we need to connect them into one.
-            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-                weight.main_grad = main_grad
+            if ctx.cpu_offloading:
+                if ctx.grad_added_to_main_grad:
+                    origin_weight = ctx.weight_object
+                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                    origin_weight.main_grad = main_grad

             ctx.ub_obj_gradout = None
             ub_obj_dgrad = None
@@ -742,10 +755,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)

-            # Don't return grad bias if not needed
-            if not ctx.use_bias:
-                grad_bias = None
-
             # Synchronize tensor parallel communication
             if ln_out_total_work is not None:
                 ln_out_total_work.wait()
@@ -827,7 +836,6 @@ class _LayerNormLinear(torch.autograd.Function):
             dbeta,
             wgrad,
             grad_bias,
-            None,  # use_bias
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
@@ -1330,8 +1338,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.layer_norm_weight,
                 self.layer_norm_bias,
                 weight_tensor,
-                bias_tensor,
-                self.apply_bias and not self.gemm_bias_unfused_add,
+                bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
                 self.eps,
                 is_first_microbatch,
                 self.fp8,
```
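Across layernorm_linear.py and the layernorm_mlp.py and linear.py changes below, the separate `use_bias` flags are dropped: the module now passes the bias tensor only when it should be applied, and the autograd function derives the flag from whether the tensor is None. A toy sketch of that calling convention (not TE code; names and shapes are made up):

```python
import torch

def toy_forward(inp: torch.Tensor, weight: torch.Tensor, bias):
    # Replaces the old explicit `use_bias: bool` argument.
    use_bias = bias is not None
    out = inp @ weight.t()
    return out + bias if use_bias else out

x, w, b = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)
apply_bias = True
print(toy_forward(x, w, b if apply_bias else None).shape)   # torch.Size([4, 16])
```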
transformer_engine/pytorch/module/layernorm_mlp.py (+7 −17)

```diff
@@ -140,10 +140,8 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_bias: torch.Tensor,
         fc1_weight: torch.Tensor,
         fc1_bias: torch.Tensor,
-        use_fc1_bias: bool,
         fc2_weight: torch.Tensor,
         fc2_bias: torch.Tensor,
-        use_fc2_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -368,7 +366,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # FC1 GEMM
-        # There are 2 fussions possible:
+        # There are 2 fusions possible:
         # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
         # - bias_gelu_fusion - only for full precision.
         # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer
@@ -453,8 +451,7 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             if not is_grad_enabled:
                 clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-            else:
+            if is_grad_enabled:
                 if cpu_offloading:
                     if fp8 and fc1_weight_final is not None:
                         set_offloading_param(fc1_weight_final, "weight_offloading", True)
@@ -537,9 +534,7 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
             ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
-            ctx.use_fc1_bias = use_fc1_bias
-            ctx.use_fc2_bias = use_fc2_bias
-            ctx.use_bias = ctx.use_fc1_bias
+            ctx.use_bias = fc2_bias is not None
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp_shape
@@ -774,14 +769,13 @@ class _LayerNormMLP(torch.autograd.Function):
                     quantization_params=None,  # wgrad in high precision
                     layout="NT",
                     grad=True,
-                    bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
+                    bias=fc2_bias if fc2_bias_grad is None else None,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     use_split_accumulator=_2X_ACC_WGRAD,
                     out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 )
                 if fc2_bias_grad is None:
                     fc2_bias_grad = fc2_bias_grad_
-                    del fc2_bias_grad_
             clear_tensor_data(act_out)

             # bias computation
@@ -1046,11 +1040,9 @@ class _LayerNormMLP(torch.autograd.Function):
             dgamma,
             dbeta,
             fc1_wgrad,
-            fc1_bias_grad if ctx.use_fc1_bias else None,
-            None,  # use_fc1_bias
+            fc1_bias_grad if fc1_bias is not None else None,
             fc2_wgrad,  # pylint: disable=possibly-used-before-assignment
-            fc2_bias_grad if ctx.use_fc2_bias else None,
-            None,  # use_fc2_bias
+            fc2_bias_grad,
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
@@ -1471,10 +1463,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.layer_norm_bias,
                 fc1_weight,
                 fc1_bias,
-                self.use_bias,
                 fc2_weight,
-                fc2_bias,
-                self.apply_bias and not self.gemm_bias_unfused_add,
+                fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
                 self.eps,
                 is_first_microbatch,
                 self.fp8,
```
transformer_engine/pytorch/module/linear.py (+16 −3)

```diff
@@ -291,6 +291,17 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
@@ -392,9 +403,11 @@ class _Linear(torch.autograd.Function):
                 else None
             )

-            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-                weight = torch.nn.Parameter(weight, weight.requires_grad)
-                weight.main_grad = main_grad
+            if ctx.cpu_offloading:
+                if ctx.grad_added_to_main_grad:
+                    weight = ctx.weight_object
+                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                    weight.main_grad = main_grad

             # Gather intermediate/activation tensors if needed
             # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
```
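The comment in the added forward-path block explains why the original weight object is stashed on `ctx`: handing a `torch.nn.Parameter` through hooks or offloading paths can return a plain `torch.Tensor` that no longer carries user-attached attributes such as `main_grad`. A toy, self-contained illustration of that effect (not TE code):

```python
import torch

class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        # Extra attribute attached by the user, as with fused wgrad accumulation.
        self.weight.main_grad = torch.zeros(4, 4)

m = ToyModule()
plain = m.weight.detach()   # a plain Tensor view; attached attributes are lost
print(hasattr(m.weight, "main_grad"), hasattr(plain, "main_grad"))  # True False
```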