OpenDAS / TransformerEngine / Commits

Commit 4099aa8e
Authored Mar 20, 2025 by yuguo

Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: c520cba3, 96f9c6de
Changes: 49

Showing 9 changed files with 1372 additions and 291 deletions (+1372, -291)
transformer_engine/pytorch/csrc/extensions/pybind.cpp            +14   -8
transformer_engine/pytorch/csrc/kv_cache.cuh                     +145  -0
transformer_engine/pytorch/distributed.py                        +15   -2
transformer_engine/pytorch/dot_product_attention/inference.py    +777  -32
transformer_engine/pytorch/dot_product_attention/utils.py        +336  -190
transformer_engine/pytorch/graph.py                              +44   -28
transformer_engine/pytorch/module/layernorm_linear.py            +18   -11
transformer_engine/pytorch/module/layernorm_mlp.py               +7    -17
transformer_engine/pytorch/module/linear.py                      +16   -3
transformer_engine/pytorch/csrc/extensions/pybind.cpp  View file @ 4099aa8e

...
@@ -177,16 +177,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef USE_ROCM
   m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
 #endif
-  m.def("fused_attn_fwd", &fused_attn_fwd,
-        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
-        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
         py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
...
@@ -194,6 +186,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
         py::call_guard<py::gil_scoped_release>());
+  // attention kernels
+  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_attn_fwd", &fused_attn_fwd,
+        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
+  m.def("fused_attn_bwd", &fused_attn_bwd,
+        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
+  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
+  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
+  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tensor from BSHD to THD");
   // fused apply rope
   m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
         py::call_guard<py::gil_scoped_release>());
...
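The three cache helpers bound here are implemented as CUDA kernels in the new kv_cache.cuh file below. As a reference for what convert_thd_to_bshd computes, here is a minimal pure-PyTorch sketch of the same copy; the function name and argument order are illustrative only and are not the extension's actual Python signature:

import torch

def thd_to_bshd_reference(tensor: torch.Tensor, cu_seqlens: torch.Tensor,
                          b: int, max_seq_len: int) -> torch.Tensor:
    # Scatter a packed [total_tokens, h, d] tensor into a padded [b, s, h, d] one;
    # sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed input.
    _, h, d = tensor.shape
    out = torch.zeros(b, max_seq_len, h, d, dtype=tensor.dtype, device=tensor.device)
    for i in range(b):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[i, : end - start] = tensor[start:end]
    return out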
transformer_engine/pytorch/csrc/kv_cache.cuh  0 → 100644  View file @ 4099aa8e

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_

namespace transformer_engine {
namespace fused_attn {

template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: thd; new_tensor: bshd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    scalar_t *thd_token = tensor + thd_offset;
    scalar_t *bshd_token = new_tensor + bshd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(bshd_token + i) = *(thd_token + i);
    }
  }
}

template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
                                           int b, int max_seq_len, int h, int d) {
  // tensor: bshd; new_tensor: thd
  // cu_seqlens: [b + 1]
  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
    int num_elts = seqlen * h * d;
    int bshd_offset = batch_idx * max_seq_len * h * d;
    int thd_offset = cu_seqlens[batch_idx] * h * d;
    scalar_t *bshd_token = tensor + bshd_offset;
    scalar_t *thd_token = new_tensor + thd_offset;
    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
      *(thd_token + i) = *(bshd_token + i);
    }
  }
}

template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                        int d_v, int b, int max_seq_len) {
  // k_cache, v_cache: bshd
  // batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
  int actual_b = b;
  for (int i = 0; i < b - 1; i++) {
    if (batch_indices[i + 1] < batch_indices[i]) {
      actual_b = i + 1;
    }
  }
  for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
    int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
    int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
    for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
      int num_elts_k = h_kv * d_k;
      int num_elts_v = h_kv * d_v;
      int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
      int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k;
      int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v;
      int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v;
      for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) {
        *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i);
      }
      for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
        *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
      }
    }
  }
}

template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
                                        scalar_t *v_cache, int *page_table, int *cu_new_lens,
                                        int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
                                        int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
                                        int max_pages_per_seq, bool is_non_paged) {
  // new_k, new_v: qkv_format; k_cache, v_cache: bshd
  // cu_new_lens, cu_cached_lens: [b + 1]
  // page_table: [b, max_pages_per_seq]
  int page_size = max_seq_len / max_pages_per_seq;
  if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int new_token_offset = batch_idx * max_ctx_len;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
        for (int j = 0; j < h_kv * d_k; j++) {
          *(k_cache + token_idx * h_kv * d_k + j) =
              *(new_k + (new_token_offset + i) * h_kv * d_k + j);
        }
        for (int j = 0; j < h_kv * d_v; j++) {
          *(v_cache + token_idx * h_kv * d_v + j) =
              *(new_v + (new_token_offset + i) * h_kv * d_v + j);
        }
      }
    }
  } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
        for (int j = 0; j < h_kv * d_k; j++) {
          *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
        }
        for (int j = 0; j < h_kv * d_v; j++) {
          *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
        }
      }
    }
  } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
    for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
      int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
      int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
      int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
      for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
        int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
        int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
        for (int j = 0; j < h_kv * d_k; j++) {
          *(k_cache + token_idx * h_kv * d_k + j) =
              *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
        }
        for (int j = 0; j < h_kv * d_v; j++) {
          *(v_cache + token_idx * h_kv * d_v + j) =
              *(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
        }
      }
    }
  }
}

}  // namespace fused_attn
}  // namespace transformer_engine
#endif
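In all three branches of copy_to_kv_cache_kernel, the destination row for a new token is derived from its absolute position in the sequence: non-paged caches use the batch index directly, paged caches look the physical page up in the page table. A standalone Python sketch of that index computation (illustrative, not part of the commit):

def cache_slot(cached_len: int, new_len: int, i: int, page_size: int,
               page_list, batch_idx: int, is_non_paged: bool) -> int:
    # Mirror of the kernel's destination-row computation for new token i.
    pos = cached_len - new_len + i               # absolute position within the sequence
    if is_non_paged:
        page_idx = batch_idx                     # non-paged: one cache row block per sequence
    else:
        page_idx = page_list[pos // page_size]   # paged: translate to a physical page
    return page_idx * page_size + pos % page_size

# e.g. page_size=4, a sequence with 10 cached tokens of which 2 are new, mapped
# through page_list=[7, 3, 9]: the first new token (i=0) sits at absolute
# position 8, i.e. page 9 slot 0, which is cache row 36.
assert cache_slot(10, 2, 0, 4, [7, 3, 9], 0, False) == 36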
transformer_engine/pytorch/distributed.py  View file @ 4099aa8e

...
@@ -20,7 +20,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager
+from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.quantized_tensor import QuantizedTensor, Quantizer
...
@@ -328,11 +328,14 @@ class _CheckpointFunction(torch.autograd.Function):
         tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
         ctx.save_for_backward(*tensor_inputs)

+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
         ctx.get_rng_state_tracker = get_rng_state_tracker
         ctx.tp_group = tp_group
         ctx.recompute_ctx = recompute_ctx
         ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
         ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
+        ctx.fp8 = fp8
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.kwargs = kwargs

         return outputs
...
@@ -375,6 +378,8 @@ class _CheckpointFunction(torch.autograd.Function):
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
             activation_recompute=True, recompute_phase=True
-        ):
+        ), fp8_autocast(enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
...
@@ -398,6 +403,9 @@ class _CheckpointFunction(torch.autograd.Function):
                 "none of output has requires_grad=True, this checkpoint() is not necessary"
             )

+        # backward does not require entering autocast context because
+        # backward implementations already retrieve fp8 recipe and
+        # enablement from stored ctx.
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(
             inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
...
@@ -694,10 +702,15 @@ def checkpoint(
     # Preserve the torch autocast contexts from the forward pass during recompute phase.
     torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
+    fp8 = FP8GlobalStateManager.is_fp8_enabled()
+    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

     def recompute_fn(*args, **kwargs):
         with torch.autograd.enable_grad(), (
             te_recompute_ctx
-        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
+            enabled=fp8, fp8_recipe=fp8_recipe
+        ):
             function(*args, **kwargs)

     # Initialize a new checkpoint frame for each new forward pass.
...
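The change above captures the ambient FP8 state at forward time (ctx.fp8, ctx.fp8_recipe) and re-enters fp8_autocast during recompute, so recomputed activations are quantized with the same recipe as the original forward pass. A minimal self-contained sketch of that capture-and-replay pattern, using a toy context manager in place of TE's fp8_autocast (names here are illustrative only):

from contextlib import contextmanager

_fp8_state = {"enabled": False, "recipe": None}

@contextmanager
def toy_fp8_autocast(enabled=False, fp8_recipe=None):
    # Stand-in for TE's fp8_autocast: set global state on entry, restore on exit.
    prev = dict(_fp8_state)
    _fp8_state.update(enabled=enabled, recipe=fp8_recipe)
    try:
        yield
    finally:
        _fp8_state.clear()
        _fp8_state.update(prev)

# Forward pass: capture the ambient state (the diff stores it on ctx).
with toy_fp8_autocast(enabled=True, fp8_recipe="delayed_scaling"):
    captured = dict(_fp8_state)

# Recompute phase: outside the original context, replay the captured state so
# the recomputed activations see the same quantization settings.
assert _fp8_state["enabled"] is False
with toy_fp8_autocast(enabled=captured["enabled"], fp8_recipe=captured["recipe"]):
    assert _fp8_state["enabled"] is True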
transformer_engine/pytorch/dot_product_attention/inference.py  View file @ 4099aa8e

@@ -2,52 +2,797 @@
 #
 # See LICENSE for license information.

Removed (the old InferenceParams, replaced wholesale by the new module below):

-"""
-Inference classes for attention
-"""
-
-
-class InferenceParams:
-    # pylint: disable=too-few-public-methods
-    """
-    Inference parameters that are passed to the main model in order
-    to efficiently calculate and store the context during inference.
-
-    Parameters
-    ----------
-    max_batch_size : int
-        maximum batch size during inference.
-    max_sequence_length : int
-        maximum sequence length during inference.
-    """
-
-    def __init__(self, max_batch_size, max_sequence_length):
-        self.max_sequence_length = max_sequence_length
-        self.max_batch_size = max_batch_size
-        self.sequence_len_offset = 0
-        self.batch_size_offset = 0
-        self.key_value_memory_dict = {}
-
-    def swap_key_value_dict(self, batch_indices):
-        """
-        Reorders the KV cache using the specified batch indices.
-
-        Parameters
-        ----------
-        batch_indices : List[int]
-            Sequence of indices to reorder along the batch dimensions of
-            the KV cache. Must have a length equal to the batch size.
-        """
-        if len(self.key_value_memory_dict) == 0:
-            raise ValueError("should not swap when dict in empty")
-
-        for layer_number, inference_memory in self.key_value_memory_dict.items():
-            inference_key_memory, inference_value_memory = inference_memory
-            assert (
-                len(batch_indices) == inference_key_memory.shape[1]
-            )  # make sure batch size is the same
-            new_inference_key_memory = inference_key_memory[:, batch_indices]
-            new_inference_value_memory = inference_value_memory[:, batch_indices]
-            self.key_value_memory_dict[layer_number] = (
-                new_inference_key_memory,
-                new_inference_value_memory,
-            )

Added (everything below, through the end of PagedKVCacheManager, is new):

"""Inference"""
import logging
from collections import OrderedDict, defaultdict
from typing import Optional, List
from einops import rearrange
import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat

__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"]
class KVCacheManager:
    """Base KV cache manager"""

    def __init__(self):
        """Initialize cache manager"""
        self.cache = {}
        self.sequences = OrderedDict()

    def reset(self):
        """Reset cache manager state"""
        self.sequences = OrderedDict()

    def allocate_memory(self, layer_number: int):
        """Allocate memory for the cache"""
        self.cache[layer_number] = (None, None)

    def pre_step(
        self,
        step_dict: OrderedDict,  # pylint: disable=unused-argument
    ):
        """Update tracked sequences and prepare for step()"""
        return self.sequences

    def step(
        self,
        layer_number: int,
        new_k: torch.Tensor,  # pylint: disable=unused-argument
        new_v: torch.Tensor,  # pylint: disable=unused-argument
        cu_new_seqlens: torch.Tensor,  # pylint: disable=unused-argument
        cu_cached_seqlens: torch.Tensor,  # pylint: disable=unused-argument
        qkv_format: str,  # pylint: disable=unused-argument
    ):
        """Copy the new tokens to KV cache"""
        return self.cache[layer_number]
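# Illustration (not part of this commit): InferenceParams accepts a custom_cache_manager
# that subclasses KVCacheManager above. The hypothetical subclass below only counts
# step() calls per layer and otherwise inherits the no-op base behaviour; a manager
# meant for InferenceParams would also need to accept the constructor keywords that
# InferenceParams forwards (max_batch_size, max_seqlen, num_heads, head_dim_k, dtype,
# head_dim_v, ...).
class CountingKVCacheManager(KVCacheManager):
    """Hypothetical example: track how often each layer's cache is stepped."""

    def __init__(self, **kwargs):
        super().__init__()
        self.step_counts = {}

    def step(self, layer_number, new_k, new_v, cu_new_seqlens, cu_cached_seqlens, qkv_format):
        self.step_counts[layer_number] = self.step_counts.get(layer_number, 0) + 1
        return super().step(
            layer_number, new_k, new_v, cu_new_seqlens, cu_cached_seqlens, qkv_format
        )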
class InferenceParams:
    """
    KV caching for inference. The memory allocation of the caches and the copying of new tokens
    to the cache take place at the following locations.::

        class TransformerLayer:
            class MultiHeadAttention:
                if self.layer_number not in inference_params.cache_manager.cache:
                    inference_params.allocate_memory(self.layer_number)
                class DotProductAttention:
                    if inference_params is not None:
                        k_cache, v_cache, new_qkv_format = inference_params.step(
                            new_k, new_v, qkv_format)
                        output = attention(new_q, k_cache, v_cache, new_qkv_format)

    allocate_memory() can be called outside the model, independently. step() can take three
    formats, qkv_format = {'bshd', 'sbhd', 'thd'}. It converts new_k and new_v to 'bshd' in both
    NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on
    the backend. If it is unchanged, we would have new_qkv_format = {'bshd', 'sbhd_2bshd',
    'thd_2bshd'}.

    A standard KV caching workflow for inference is as follows.::

        model = [TransformerLayer() for _ in range(num_layers)]
        # initialize InferenceParams, e.g. with PagedKVCacheManager
        inference_params = InferenceParams(..., is_paged=True)
        # inference loop
        for i in range(num_iters):
            # get info for iteration i, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1]
            step_dict = OrderedDict(zip(seq_ids, step_lens))
            # update inference_params' state
            inference_params.pre_step(step_dict)
            # run iteration
            output = model(
                ...,
                attn_mask_type="padding_causal",
                cu_seqlens_q=cu_seqlens_new_q,
                cu_seqlens_kv=cu_seqlens_new_kv,
                inference_params=inference_params,
            )
            # get output tokens based on qkv_format
            # 'bshd': output = output[:,step_dict.values()-1]
            # 'sbhd': output = output[step_dict.values()-1,:]
            # 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1

    Parameters
    ----------
    max_batch_size: int
        Maximum batch size in inference
    max_seqlen_kv: int
        Maximum sequence length in inference
    num_heads_kv: int
        Number of attention heads in keys and values
    head_dim_k: int
        Head size for keys
    dtype: torch.dtype
        Data type of the KV cache
    head_dim_v: int, default = None
        Head size for values. If None, initialized as head_dim_k.
    is_paged: bool, default = False
        Whether the KV cache is paged (True) or non-paged (False)
    total_num_pages: int, default = None
        Total number of pages in the KV cache. Required for is_paged = True.
    page_size: int, default = None
        Page size of the KV cache. Required for is_paged = True.
    max_ctx_len: int, default = None
        Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv.
    qkv_format: str, default = "bshd"
        Format of the incoming query/key/value tensors in current iteration
    custom_cache_manager: KVCacheManager, default = None
        Custom cache manager, with KVCacheManager as the base class.
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seqlen_kv: int,
        num_heads_kv: int = 16,
        head_dim_k: int = 64,
        dtype: torch.dtype = torch.bfloat16,
        head_dim_v: int = None,
        is_paged: bool = False,
        total_num_pages: int = None,
        page_size: int = None,
        max_ctx_len: int = None,
        qkv_format: str = "bshd",
        custom_cache_manager: KVCacheManager = None,
    ):
        self.max_batch_size = max_batch_size
        self.max_seqlen_kv = max_seqlen_kv
        self.num_heads_kv = num_heads_kv
        self.head_dim_k = head_dim_k
        self.dtype = dtype
        self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
        self.is_paged = is_paged

        if not self.is_paged:
            cache_manager = (
                custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager
            )
            self.cache_manager = cache_manager(
                max_batch_size=self.max_batch_size,
                max_seqlen=self.max_seqlen_kv,
                num_heads=self.num_heads_kv,
                head_dim_k=self.head_dim_k,
                dtype=self.dtype,
                head_dim_v=self.head_dim_v,
            )
        else:
            assert page_size is not None, "Paged KV cache requires page_size is not None."
            self.page_size = page_size
            assert (
                max_seqlen_kv % page_size == 0
            ), "Paged KV cache requires max_seqlen_kv % page_size = 0."
            max_pages_per_seq = max_seqlen_kv // page_size
            assert (
                total_num_pages == self.max_batch_size * max_pages_per_seq
            ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
            self.total_num_pages = total_num_pages
            cache_manager = (
                custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager
            )
            self.cache_manager = cache_manager(
                total_num_pages=self.total_num_pages,
                page_size=self.page_size,
                num_heads=self.num_heads_kv,
                head_dim_k=self.head_dim_k,
                dtype=self.dtype,
                max_batch_size=self.max_batch_size,
                max_seqlen=self.max_seqlen_kv,
                head_dim_v=self.head_dim_v,
            )

        if qkv_format == "thd":
            assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!"
            self.max_ctx_len = max_ctx_len

        self.cache_qkv_format = "bshd"
        self.input_qkv_format = qkv_format
        if self.input_qkv_format == self.cache_qkv_format:
            self.output_qkv_format = self.cache_qkv_format
        else:
            self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format

        self.sequences_pre_step = OrderedDict()
        self.sequences = OrderedDict()
        self.batch_size = 0
        self.cu_seqlens_q = torch.zeros(
            self.max_batch_size + 1,
            dtype=torch.int32,
            device=torch.cuda.current_device(),
        )
        self.cu_seqlens_kv = torch.zeros(
            self.max_batch_size + 1,
            dtype=torch.int32,
            device=torch.cuda.current_device(),
        )

    def reset(self):
        """Reset InferenceParams state"""
        self.sequences = OrderedDict()
        self.cache_manager.reset()

    def __repr__(self) -> str:
        if self.is_paged:
            return (
                f"dtype={self.dtype}, "
                f"is_paged={self.is_paged}, "
                f"total_pages={self.total_num_pages}, "
                f"page_size={self.page_size}, "
                f"num_heads={self.num_heads_kv}, "
                f"head_dim_k={self.head_dim_k}, "
                f"head_dim_v={self.head_dim_v}"
            )
        return (
            f"dtype={self.dtype}, "
            f"is_paged={self.is_paged}, "
            f"max_batch_size={self.max_batch_size}, "
            f"max_seqlen={self.max_seqlen_kv}, "
            f"num_heads={self.num_heads_kv}, "
            f"head_dim_k={self.head_dim_k}, "
            f"head_dim_v={self.head_dim_v}"
        )

    def allocate_memory(self, layer_number: int):
        """
        Allocate memory for the cache. For layer layer_number,
        - NonPagedKVCacheManager:
          - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k]
          - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v]
        - PagedKVCacheManager:
          - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
          - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
        """
        self.cache_manager.allocate_memory(layer_number)

    def pre_step(
        self,
        step_dict: OrderedDict,
    ):
        """Update tracked sequences and prepare for step()"""
        self.batch_size = len(step_dict)
        self.sequences = self.cache_manager.pre_step(step_dict)

        # track the pre-step seqlens for the next layer in the model
        self.sequences_pre_step = OrderedDict()
        for k, v in self.sequences.items():
            self.sequences_pre_step[k] = v - step_dict[k]

        seqlens_q = list(step_dict.values())
        cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
        cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
        self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu"))

        seqlens_kv = list(self.sequences.values())
        cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)]
        cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * (
            self.max_batch_size - self.batch_size
        )
        self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu"))

    def get_seqlens_pre_step(self):
        """Get cached sequence lengths before the stepping"""
        return torch.Tensor(list(self.sequences_pre_step.values())).to(
            dtype=torch.int32, device="cpu"
        )

    def convert_paged_to_nonpaged(self, layer_number: int):
        """
        Convert k_cache and v_cache from paged to non-paged format.

        Parameters
        ----------
        layer_number: int
            Layer number of attention in the model

        Returns
        -------
        k_cache: torch.Tensor
            Non-paged key cache tensor
        v_cache: torch.Tensor
            Non-paged value cache tensor
        """
        k_cache, v_cache = self.cache_manager.cache[layer_number]
        page_table = self.cache_manager.page_table
        batch_size = page_table.shape[0]
        new_k_cache = rearrange(
            k_cache[page_table.flatten()],
            "(b npages) page_size ... -> b (npages page_size) ...",
            b=batch_size,
        )
        new_v_cache = rearrange(
            v_cache[page_table.flatten()],
            "(b npages) page_size ... -> b (npages page_size) ...",
            b=batch_size,
        )
        new_k_cache = new_k_cache[: self.batch_size].contiguous()
        new_v_cache = new_v_cache[: self.batch_size].contiguous()
        return new_k_cache, new_v_cache

    def step(
        self,
        layer_number: int,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        qkv_format: str,
    ):
        """
        Copy new KV tokens to the cache.

        Parameters
        ----------
        layer_number: int
            Layer number of attention in the model
        new_k: torch.Tensor
            New key tokens for layer_number in current inference iteration
        new_v: torch.Tensor
            New value tokens for layer_number in current inference iteration
        qkv_format: str
            Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}

        Returns
        -------
        k_cache: torch.Tensor
            Full key tensor containing both previous and current key tokens
        v_cache: torch.Tensor
            Full value tensor containing both previous and current value tokens
        cu_seqlens_q: torch.Tensor
            Updated cumulative sequence lengths for query, [batch_size + 1]
        cu_seqlens_kv: torch.Tensor
            Updated cumulative sequence lengths for key and value, [batch_size + 1]
        max_seqlen_q: int
            Updated maximum sequence length for query
        max_seqlen_kv: int
            Updated maximum sequence length for key and value
        qkv_format: str
            Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
        """
        self.input_qkv_format = qkv_format
        if self.input_qkv_format == self.cache_qkv_format:
            self.output_qkv_format = self.cache_qkv_format
        else:
            self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format

        k_cache, v_cache = self.cache_manager.step(
            layer_number,
            new_k,
            new_v,
            self.cu_seqlens_q,
            self.cu_seqlens_kv,
            qkv_format,
        )

        return (
            k_cache,
            v_cache,
            self.cu_seqlens_q,
            self.cu_seqlens_kv,
            self.max_seqlen_kv,
            self.output_qkv_format,
        )
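# Illustration (not part of this commit): driving the pre_step() cycle from the
# docstring's workflow with toy sizes. allocate_memory() creates CUDA tensors, so
# this sketch assumes a GPU. __init__ checks max_seqlen_kv % page_size == 0 and
# total_num_pages == max_batch_size * (max_seqlen_kv // page_size); 16 == 4 * (64 // 16).
def _example_paged_inference_loop():
    params = InferenceParams(
        max_batch_size=4, max_seqlen_kv=64, num_heads_kv=8, head_dim_k=64,
        dtype=torch.bfloat16, is_paged=True, total_num_pages=16, page_size=16,
    )
    params.allocate_memory(layer_number=1)
    params.pre_step(OrderedDict([(0, 10), (1, 7)]))  # prefill: sequences 0 and 1
    params.pre_step(OrderedDict([(0, 1), (1, 1)]))   # decode: one new token each
    return params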
class NonPagedKVCacheManager(KVCacheManager):
    """Non-paged KV cache manager"""

    def __init__(
        self,
        max_batch_size: int,
        max_seqlen: int,
        num_heads: int,
        head_dim_k: int,
        dtype: torch.dtype,
        head_dim_v: Optional[int] = None,
    ):
        super().__init__()
        """Initialize cache manager"""
        self.max_batch_size = max_batch_size
        self.max_seqlen = max_seqlen
        self.num_heads = num_heads
        self.head_dim_k = head_dim_k
        self.dtype = dtype
        self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
        # track sequences in the cache, {seq_id: seq_len}
        self.sequences = OrderedDict()
        # cache tensors, cache[layer_number] = (k_cache, v_cache)
        self.cache = {}
        # track sequence indices in the batch in order to re-index k_cache and v_cache
        self.batch_indices = torch.zeros(
            self.max_batch_size,
            dtype=torch.int32,
            device=torch.cuda.current_device(),
        )
        # after re-indexing, batch indices are always [0, ..., b-1]
        self.batch_indices_post_step = torch.range(
            0,
            self.max_batch_size - 1,
            dtype=torch.int32,
            device=torch.cuda.current_device(),
        )

    def allocate_memory(self, layer_number):
        """Allocate memory for the cache"""
        k_cache = torch.zeros(
            self.max_batch_size,
            self.max_seqlen,
            self.num_heads,
            self.head_dim_k,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
        )
        v_cache = torch.zeros(
            self.max_batch_size,
            self.max_seqlen,
            self.num_heads,
            self.head_dim_v,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
        )
        self.cache[layer_number] = (k_cache, v_cache)

    def pre_step(
        self,
        step_dict: OrderedDict,
    ):
        """Update tracked sequences and prepare for step()"""
        # Track unfinished sequences' indices in the batch, e.g.
        # at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished
        # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that
        # they are contiguous and match the indexing in q
        prev_batch_size = len(self.sequences)
        unfinished_seqs = self.sequences.keys() & step_dict.keys()
        finished_seqs = self.sequences.keys() - unfinished_seqs
        unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
        finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
        self.batch_indices.copy_(
            torch.Tensor(
                (
                    unfinished_indices
                    + finished_indices
                    + list(range(prev_batch_size, self.max_batch_size))
                )
            ).to(dtype=torch.int32, device="cpu")
        )

        # Advance unfinished sequences
        for i in unfinished_seqs:
            self.sequences[i] += 1
        # Remove finished sequences
        for i in finished_seqs:
            self.sequences.pop(i)
        # Add new sequences
        new_seqs = step_dict.keys() - self.sequences.keys()
        for i in new_seqs:
            self.sequences[i] = step_dict[i]
        return self.sequences

    def step(
        self,
        layer_number,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        cu_new_seqlens,
        cu_cached_seqlens,
        qkv_format: str,
    ):
        """
        Copy the new tokens to the non-paged KV cache.

        Parameters
        ----------
        layer_number: int
            Layer number of attention in the model
        new_k: torch.Tensor
            New key tokens for layer_number in current inference iteration
        new_v: torch.Tensor
            New value tokens for layer_number in current inference iteration
        cu_new_seqlens: torch.Tensor
            Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
        cu_cached_seqlens: torch.Tensor
            Cumulative sequence lengths for k_cache and v_cache (after new tokens are
            copied in), in shape [batch_size + 1]
        qkv_format: str
            Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}

        Returns
        -------
        k_cache: torch.Tensor
            Full key tensor containing both previous and current key tokens
        v_cache: torch.Tensor
            Full value tensor containing both previous and current value tokens
        """
        k_cache, v_cache = self.cache[layer_number]
        batch_size = self.max_batch_size
        ctx_len = 1
        if qkv_format == "bshd":
            batch_size = new_k.shape[0]
            ctx_len = new_k.shape[1]
        if qkv_format == "sbhd":
            batch_size = new_k.shape[1]
            ctx_len = new_k.shape[0]
        tex.copy_to_kv_cache(
            new_k,
            new_v,
            k_cache,
            v_cache,
            self.batch_indices,
            cu_new_seqlens,
            cu_cached_seqlens,
            QKVFormat[qkv_format],
            batch_size,
            ctx_len,
            self.max_seqlen,
            1,
            True,
        )
        k_cache = k_cache[:batch_size]
        v_cache = v_cache[:batch_size]
        return k_cache, v_cache
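# Illustration (not part of this commit): a standalone replay of the re-indexing
# comment in pre_step() above. The previous batch held seq_ids [0, 1, 2, 3]; seq 1
# finished, so the unfinished rows [0, 2, 3] move to the front of the cache.
sequences = {0: 5, 1: 9, 2: 5, 3: 5}   # {seq_id: cached length} at t-1
step_dict = {0: 1, 2: 1, 3: 1}         # sequences still active at t
max_batch_size = 6
prev_batch_size = len(sequences)
unfinished = sequences.keys() & step_dict.keys()
finished = sequences.keys() - unfinished
unfinished_idx = [i for i, j in enumerate(sequences) if j in unfinished]
finished_idx = [i for i, j in enumerate(sequences) if j in finished]
batch_indices = unfinished_idx + finished_idx + list(range(prev_batch_size, max_batch_size))
assert batch_indices == [0, 2, 3, 1, 4, 5]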
class Page:
    """A single page"""

    def __init__(self, page_id: int):
        """Initialize a page"""
        self.page_id = page_id
        self.allocated = 0

    def allocate_page(self):
        """Allocate a page"""
        self.allocated = True

    def deallocate_page(self):
        """Deallocate a page"""
        self.allocated = False
class PagedKVCacheManager(KVCacheManager):
    """Paged KV cache manager"""

    def __init__(
        self,
        total_num_pages: int,
        page_size: int,
        num_heads: int,
        head_dim_k: int,
        dtype: torch.dtype,
        max_batch_size: int,
        max_seqlen: int,
        head_dim_v: Optional[int] = None,
    ):
        super().__init__()
        """Initialize cache manager"""
        self.total_num_pages = total_num_pages
        self.page_size = page_size
        self.num_heads = num_heads
        self.head_dim_k = head_dim_k
        self.dtype = dtype
        self.max_batch_size = max_batch_size
        self.max_seqlen = max_seqlen
        self.max_pages_per_seq = max_seqlen // self.page_size
        self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
        # track sequences in the cache, {seq_id: seq_len}
        self.sequences = OrderedDict()
        # cache tensors, cache[layer_number] = (k_cache, v_cache)
        self.cache = {}
        # available pages, [Page(),...]
        self.free_pages = []
        for i in range(self.total_num_pages):
            self.free_pages.append(Page(i))
        # allocated pages, {seq_id: [page_id,...]}
        self.allocated_pages = defaultdict(list)
        # page table, [batch_size, max_pages_per_seq]
        self.page_table = torch.zeros(
            self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda"
        )

    def reset(self):
        """Reset cache manager state"""
        self.sequences = OrderedDict()
        self.free_pages = []
        for i in range(self.total_num_pages):
            self.free_pages.append(Page(i))
        self.allocated_pages = defaultdict(list)
        self.page_table.fill_(0)

    def allocate_memory(self, layer_number):
        """Allocate memory for the cache"""
        k_cache = torch.zeros(
            self.total_num_pages,
            self.page_size,
            self.num_heads,
            self.head_dim_k,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
        )
        v_cache = torch.zeros(
            self.total_num_pages,
            self.page_size,
            self.num_heads,
            self.head_dim_v,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
        )
        self.cache[layer_number] = (k_cache, v_cache)

    def print_cache(self):
        """Print KV cache status"""
        used_pages = [self.get_page_count(seq) for seq in self.sequences]
        logger = logging.getLogger("PagedKVCacheManager")
        logger.debug("Cache status:")
        logger.debug(
            "  total pages: %s (used %s, free %s)",
            self.total_num_pages,
            sum(used_pages),
            len(self.free_pages),
        )
        logger.debug("  total sequences: %s", self.get_sequence_count())
        for i, seq in enumerate(self.sequences):
            logger.debug(
                "  >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s",
                i,
                seq,
                self.get_sequence_lengths()[i],
                self.get_page_count(seq),
                self.get_page_list(seq),
            )

    def get_sequence_count(self):
        """Get the total number of sequences in the KV cache"""
        return len(self.sequences)

    def get_sequence_lengths(self):
        """Get the list of sequence lengths in the KV cache"""
        return list(self.sequences.values())

    def has_free_page(self) -> bool:
        """Whether the page pool has any free pages left"""
        return len(self.free_pages) > 0

    def get_page_count(self, seq: int):
        """Get the number of pages allocated to a sequence"""
        return len(self.allocated_pages[seq])

    def get_page_list(self, seq: int):
        """Get the list of pages allocated to a sequence"""
        return [x.page_id for x in self.allocated_pages[seq]]

    def get_page_table(self, sequences: List[int]):
        """Get the page table, in shape [batch_size, max_pages_per_seq]"""
        page_table = torch.Tensor(
            [
                self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
                for seq in sequences
            ]
        ).to(dtype=torch.int32, device="cpu")
        self.page_table[: self.get_sequence_count()].copy_(page_table)
        return self.page_table

    def allocate_page(self, seq: int):
        """Allocate a new page to a sequence"""
        if not self.has_free_page():
            raise RuntimeError("KV cache is full!")
        page = self.free_pages.pop(0)
        page.allocate_page()
        self.allocated_pages[seq].append(page)

    def allocate_sequence(self, seq: int, context_len: int):
        """Add a new sequence to the cache"""
        num_pages = context_len // self.page_size
        if context_len % self.page_size > 0:
            num_pages = num_pages + 1
        for _ in range(num_pages):
            self.allocate_page(seq)

    def deallocate_sequence(self, seq: int):
        """Deallocate all the pages for a sequence"""
        for page in self.allocated_pages[seq]:
            page.deallocate_page()
            if not page.allocated:
                self.free_pages.append(page)
        self.allocated_pages.pop(seq)

    def pre_step(
        self,
        step_dict: OrderedDict,
    ):
        """Update tracked sequences and prepare for step()"""
        # Remove finished sequences and advance unfinished sequences
        unfinished_seqs = self.sequences.keys() & step_dict.keys()
        finished_seqs = self.sequences.keys() - unfinished_seqs
        for seq in finished_seqs:
            self.sequences.pop(seq)
            self.deallocate_sequence(seq)
        for seq in unfinished_seqs:
            if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen:
                self.allocate_page(seq)
            self.sequences[seq] += 1

        # Add new sequences
        new_seqs = step_dict.keys() - self.sequences.keys()
        for seq in new_seqs:
            self.sequences[seq] = step_dict[seq]
            self.allocate_sequence(seq, step_dict[seq])

        # Get page table
        self.page_table = self.get_page_table(list(self.sequences.keys()))

        return self.sequences

    def step(
        self,
        layer_number: int,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        cu_new_seqlens,
        cu_cached_seqlens,
        qkv_format: str,
    ):
        """
        Copy the new tokens to the paged KV cache.

        Parameters
        ----------
        layer_number: int
            Layer number of attention in the model
        new_k: torch.Tensor
            New key tokens for layer_number in current inference iteration
        new_v: torch.Tensor
            New value tokens for layer_number in current inference iteration
        cu_new_seqlens: torch.Tensor
            Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
        cu_cached_seqlens: torch.Tensor
            Cumulative sequence lengths for k_cache and v_cache (after new tokens are
            copied in), in shape [batch_size + 1]
        qkv_format: str
            Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}

        Returns
        -------
        k_cache: torch.Tensor
            Full key tensor containing both previous and current key tokens
        v_cache: torch.Tensor
            Full value tensor containing both previous and current value tokens
        """
        k_cache, v_cache = self.cache[layer_number]
        batch_size = self.max_batch_size
        ctx_len = 1
        if qkv_format == "bshd":
            batch_size = new_k.shape[0]
            ctx_len = new_k.shape[1]
        if qkv_format == "sbhd":
            batch_size = new_k.shape[1]
            ctx_len = new_k.shape[0]
        tex.copy_to_kv_cache(
            new_k,
            new_v,
            k_cache,
            v_cache,
            self.page_table,
            cu_new_seqlens,
            cu_cached_seqlens,
            QKVFormat[qkv_format],
            batch_size,
            ctx_len,
            self.max_seqlen,
            self.max_pages_per_seq,
            False,
        )
        return k_cache, v_cache
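PagedKVCacheManager.allocate_sequence() rounds the context length up to whole pages, and pre_step() grabs a fresh page exactly when a sequence's cached length reaches a page boundary. A small standalone check of that arithmetic with toy numbers:

page_size = 4
context_len = 10
# ceil-division as written in allocate_sequence()
num_pages = context_len // page_size
if context_len % page_size > 0:
    num_pages = num_pages + 1
assert num_pages == 3  # pages hold tokens 0-3, 4-7, 8-9
# pre_step() allocates the next page when seq_len % page_size == 0, i.e. when
# every allocated page is full and the next decoded token needs a new one.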
transformer_engine/pytorch/dot_product_attention/utils.py  View file @ 4099aa8e

...
@@ -34,6 +34,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     META_O_CP,
     META_DQKV_CP,
 )
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
 from transformer_engine.pytorch.constants import TE_DType
...
@@ -91,7 +92,6 @@ class FlashAttentionUtils:
     Manage Flash Attention versioning information
     """

-    # Detect flash-attn v2 in the environment
     is_installed = False
     version = PkgVersion("0")
     version_required = PkgVersion("2.1.1")
...
@@ -102,21 +102,25 @@ class FlashAttentionUtils:
     v2_3_plus = False
     v2_4_plus = False
     v2_4_1_plus = False
+    v2_5_plus = False
     v2_5_7_plus = False
     v2_6_0_plus = False
     v2_7_0_plus = False
-    warning_printed = False

     v3_is_installed = False
     fa3_version = PkgVersion("0")
     v3_0_0_beta = False
     use_v3 = False
-    # TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
+    # FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper
+    # https://github.com/Dao-AILab/flash-attention/issues/1452
+    # Please follow these instructions to install FA3
     v3_installation_steps = """\
-(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
-(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
-(3) mkdir -p $python_path/flashattn_hopper
-(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py"""
+(1) git clone https://github.com/Dao-AILab/flash-attention.git
+(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
+(4) mkdir -p $python_path/flash_attn_3
+(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
+    v3_warning_printed = False

     @staticmethod
     def set_flash_attention_version():
...
@@ -129,13 +133,11 @@ class FlashAttentionUtils:
         FlashAttentionUtils.v2_3_plus = FlashAttentionUtils.version >= PkgVersion("2.3")
         FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4")
         FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1")
+        FlashAttentionUtils.v2_5_plus = FlashAttentionUtils.version >= PkgVersion("2.5.0")
         FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7")
         FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0")
         FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0")

-    # Detect flash-attn v3 in the environment
-    # This section will be removed when FA3 is released as a regular FA package,
-    # i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
     @staticmethod
     def set_flash_attention_3_params():
         """
...
@@ -145,7 +147,6 @@ class FlashAttentionUtils:
         FlashAttentionUtils.v3_0_0_beta = (
             PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0")
         )
-        FlashAttentionUtils.use_v3 = True

 @dataclass(eq=True)
...
@@ -203,6 +204,8 @@ class AttentionParams:
...
@@ -203,6 +204,8 @@ class AttentionParams:
Whether `DotProductAttention` is in an `fp8_autocast` region.
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None`
fp8_meta: Optional[Dict[str Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
"""
"""
qkv_type
:
Union
[
torch
.
Tensor
,
Float8Tensor
]
=
torch
.
Tensor
qkv_type
:
Union
[
torch
.
Tensor
,
Float8Tensor
]
=
torch
.
Tensor
...
@@ -228,6 +231,7 @@ class AttentionParams:
...
@@ -228,6 +231,7 @@ class AttentionParams:
is_training
:
bool
=
True
is_training
:
bool
=
True
fp8
:
bool
=
False
fp8
:
bool
=
False
fp8_meta
:
Union
[
Dict
[
str
,
Any
],
None
]
=
None
fp8_meta
:
Union
[
Dict
[
str
,
Any
],
None
]
=
None
inference_params
:
Optional
[
InferenceParams
]
=
None
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
"""
"""
...
@@ -298,6 +302,7 @@ def get_attention_backend(
     is_training = attention_params.is_training
     fp8 = attention_params.fp8
     fp8_meta = attention_params.fp8_meta
+    inference_params = attention_params.inference_params

     # Run config
     logger = logging.getLogger("DotProductAttention")
...
@@ -334,13 +339,19 @@ def get_attention_backend(
     # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
     # necessary for performance/functionality, a warning will be issued to prompt users to
     # install an appropriate FA version.
+    qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params)

     # Filter: Environment variables
     use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+    use_flash_attention_2 = use_flash_attention
+    use_flash_attention_3 = use_flash_attention
+    flash_attention_backend = None
     use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
     use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
-    if not use_flash_attention and FlashAttentionUtils.is_installed:
-        logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
+    if not use_flash_attention_2 and FlashAttentionUtils.is_installed:
+        logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0")
+    if not use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
+        logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0")
     if not use_fused_attention:
         logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
     if not use_unfused_attention:
...
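# Illustration (not part of this commit): the environment-variable filter above
# gates all three backend families before any capability checks run, and
# NVTE_FLASH_ATTN now disables both FlashAttention 2 and 3. Steering the
# selection from user code, with the variable names read above:
#
#     import os
#     os.environ["NVTE_FLASH_ATTN"] = "0"    # rules out FlashAttention 2 and 3
#     os.environ["NVTE_FUSED_ATTN"] = "1"    # leaves FusedAttention eligible
#     os.environ["NVTE_UNFUSED_ATTN"] = "1"  # leaves the unfused fallback eligible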
@@ -348,70 +359,134 @@ def get_attention_backend(
...
@@ -348,70 +359,134 @@ def get_attention_backend(
# Filter: Compute capability
# Filter: Compute capability
if
not
IS_HIP_EXTENSION
and
device_compute_capability
<
(
8
,
0
):
if
not
IS_HIP_EXTENSION
and
device_compute_capability
<
(
8
,
0
):
if
use_flash_attention
and
FlashAttentionUtils
.
is_installed
:
if
use_flash_attention
_2
and
FlashAttentionUtils
.
is_installed
:
logger
.
debug
(
"Disabling FlashAttention
as it requires
compute capability sm80
+
"
)
logger
.
debug
(
"Disabling FlashAttention
2 for
compute capability
<
sm80"
)
use_flash_attention
=
False
use_flash_attention
_2
=
False
if
use_fused_attention
:
if
use_fused_attention
:
logger
.
debug
(
"Disabling FusedAttention
as it requires
compute capability sm80
+
"
)
logger
.
debug
(
"Disabling FusedAttention
for
compute capability
<
sm80"
)
use_fused_attention
=
False
use_fused_attention
=
False
if
device_compute_capability
<
(
9
,
0
):
if
device_compute_capability
!=
(
9
,
0
):
if
use_flash_attention
and
FlashAttentionUtils
.
v3_is_installed
:
if
use_flash_attention
_3
and
FlashAttentionUtils
.
v3_is_installed
:
logger
.
debug
(
"Disabling FlashAttention 3
as it requires
compute capability sm90
+
"
)
logger
.
debug
(
"Disabling FlashAttention 3
for
compute capability
!=
sm90"
)
F
lash
A
ttention
Utils
.
use_v
3
=
False
use_f
lash
_a
ttention
_
3
=
False
# Filter: Data type
# Filter: Data type
if
qkv_dtype
not
in
[
torch
.
bfloat16
,
torch
.
float16
]
or
qkv_type
not
in
[
if
qkv_dtype
not
in
[
torch
.
bfloat16
,
torch
.
float16
]:
if
use_flash_attention_2
and
FlashAttentionUtils
.
is_installed
:
logger
.
debug
(
"Disabling FlashAttention 2 for unsupported qkv_dtype = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
,
qkv_dtype
,
)
use_flash_attention_2
=
False
if
qkv_dtype
not
in
[
torch
.
bfloat16
,
torch
.
float16
,
torch
.
float8_e4m3fn
]
or
qkv_type
not
in
[
torch
.
Tensor
,
torch
.
Tensor
,
Float8Tensor
,
Float8Tensor
,
]:
]:
if
use_flash_attention
and
FlashAttentionUtils
.
is_installed
:
if
use_flash_attention
_3
and
FlashAttentionUtils
.
v3_
is_installed
:
logger
.
debug
(
logger
.
debug
(
"Disabling FlashAttention
due to
unsupported
QKV data type
. "
"Disabling FlashAttention
3 for
unsupported
qkv_dtype = %s, qkv_type = %s
. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16
}.
"
"Supported: qkv_dtype = {torch.bfloat16, torch.float16
, torch.float8_e4m3fn},
"
"
Found:
qkv_
d
type =
%s.
"
,
"qkv_type =
{torch.Tensor, Float8Tensor}.
"
,
qkv_dtype
,
qkv_dtype
,
qkv_type
,
)
)
use_flash_attention
=
False
use_flash_attention
_3
=
False
if
use_fused_attention
:
if
use_fused_attention
:
logger
.
debug
(
logger
.
debug
(
"Disabling FusedAttention
due to
unsupported
QKV data type
. "
"Disabling FusedAttention
for
unsupported
qkv_dtype = %s, qkv_type = %s
. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16
}.
"
"Supported: qkv_dtype = {torch.bfloat16, torch.float16
, torch.float8_e4m3fn},
"
"
Found:
qkv_
d
type =
%s.
"
,
"qkv_type =
{torch.Tensor, Float8Tensor}.
"
,
qkv_dtype
,
qkv_dtype
,
qkv_type
,
)
)
use_fused_attention
=
False
use_fused_attention
=
False
# Filter: Execution type
# Filter: Execution type
if
fp8
and
fp8_meta
[
"recipe"
].
fp8_dpa
:
if
fp8
and
fp8_meta
[
"recipe"
].
fp8_dpa
:
if
use_flash_attention
and
not
FlashAttentionUtils
.
use_v3
:
if
use_flash_attention_2
and
FlashAttentionUtils
.
is_installed
:
if
FlashAttentionUtils
.
is_installed
:
logger
.
debug
(
"Disabling FlashAttention 2 for FP8 attention"
)
logger
.
debug
(
"Disabling FlashAttention as FlashAttention 2 does not support FP8"
)
use_flash_attention_2
=
False
use_flash_attention
=
False
if
use_flash_attention_3
and
is_training
:
if
use_flash_attention
and
FlashAttentionUtils
.
use_v3
and
is_training
:
if
FlashAttentionUtils
.
v3_is_installed
:
logger
.
debug
(
logger
.
debug
(
"Disabling FlashAttention 3 for FP8 training"
)
"Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
use_flash_attention_3
=
False
)
use_flash_attention
=
False
if
use_unfused_attention
:
if
use_unfused_attention
:
logger
.
debug
(
"Disabling UnfusedDotProductAttention
as it does not support FP8
"
)
logger
.
debug
(
"Disabling UnfusedDotProductAttention
for FP8 attention
"
)
use_unfused_attention
=
False
use_unfused_attention
=
False
# TODO: rocm fused attention backends does not support fp8 yet
# TODO: rocm fused attention backends does not support fp8 yet
if
IS_HIP_EXTENSION
and
use_fused_attention
:
if
IS_HIP_EXTENSION
and
use_fused_attention
:
logger
.
debug
(
"Disabling ROCm FusedAttention as it does not support FP8"
)
logger
.
debug
(
"Disabling ROCm FusedAttention as it does not support FP8"
)
use_fused_attention
=
False
use_fused_attention
=
False
+    # Filter: KV cache
+    # backend    | precision      | KV cache        | architecture | qkv_format    | page_size
+    # ---------------------------------------------------------------------------------------
+    # Fused      | FP16/BF16      | non-paged/paged | sm80+        | bshd,sbhd,thd | >= 1
+    # Flash v2   | FP16/BF16      | non-paged/paged | sm80+        | bshd,sbhd,thd | >= 256
+    # Flash v3   | FP16/BF16      | non-paged/paged | sm90         | bshd,sbhd,thd | >= 1
+    #            | FP8            | non-paged/paged | sm90         | thd           | >= 1
+    # Unfused    | FP32/FP16/BF16 | non-paged/paged | all          | bshd,sbhd,thd | >= 1
+    if inference_params is not None:
+        if context_parallel:
+            logger.debug("Disabling all backends for KV caching with context parallelism")
+            use_flash_attention = False
+            use_fused_attention = False
+            use_unfused_attention = False
+        if fp8 and fp8_meta["recipe"].fp8_dpa:
+            if fp8_meta["recipe"].fp8_mha:
+                logger.debug("Disabling all backends for KV caching with FP8 MHA")
+                use_flash_attention = False
+                use_fused_attention = False
+                use_unfused_attention = False
+            if use_flash_attention_3 and q_format != "thd":
+                if FlashAttentionUtils.v3_is_installed:
+                    logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD")
+                use_flash_attention_3 = False
+            if use_fused_attention:
+                logger.debug("Disabling FusedAttention for FP8 KV caching")
+                use_fused_attention = False
+        else:
+            if q_format == "thd" and pad_between_seqs:
+                logger.debug("Disabling all backends for pad_between_seqs = True and KV caching")
+                use_flash_attention = False
+                use_fused_attention = False
+                use_unfused_attention = False
+        if inference_params.is_paged:
+            if use_flash_attention_2 and inference_params.page_size < 256:
+                if FlashAttentionUtils.is_installed:
+                    logger.debug("Disabling FlashAttention 2 for page size < 256")
+                use_flash_attention_2 = False
+            if use_flash_attention_2:
+                if not FlashAttentionUtils.is_installed:
+                    FlashAttentionUtils.version_required = PkgVersion("2.5")
+                elif not FlashAttentionUtils.v2_5_plus:
+                    logger.debug(
+                        "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+"
+                    )
+                    use_flash_attention_2 = False
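The page-size row for Flash v2 in the support matrix above is the one constraint that depends on cache geometry rather than on dtype or architecture. A hedged sketch of that check in isolation; the config type here is hypothetical, not part of TE.

from dataclasses import dataclass

@dataclass
class PagedKVConfig:
    backend: str      # "fused", "flash_v2", "flash_v3", or "unfused"
    page_size: int

def is_supported(cfg: PagedKVConfig) -> bool:
    # Mirrors the comment table: FlashAttention 2 is the only backend
    # with a page-size floor (>= 256); everything else accepts >= 1.
    min_page_size = 256 if cfg.backend == "flash_v2" else 1
    return cfg.page_size >= min_page_size

assert is_supported(PagedKVConfig("fused", 16))
assert not is_supported(PagedKVConfig("flash_v2", 64))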
     # Filter: Head dimension
     if not IS_HIP_EXTENSION:
-        if use_flash_attention and head_dim_qk != head_dim_v:
-            if FlashAttentionUtils.is_installed:
+        if head_dim_qk != head_dim_v:
+            if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
+                use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
+            ):
                 logger.debug("Disabling FlashAttention as it does not support MLA.")
             use_flash_attention = False
+            qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
+            if use_fused_attention and qkv_layout_group != "hd_hd_hd":
+                logger.debug(
+                    "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
+                    qkv_layout,
+                )
+                use_fused_attention = False
     else:
         if use_fused_attention and head_dim_qk != head_dim_v:
             logger.debug("Disabling FusedAttention as it does not support MLA in rocm backend.")
             use_fused_attention = False
-    if use_flash_attention and (
+    if use_flash_attention_2 and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
         or (
...
@@ -421,7 +496,7 @@ def get_attention_backend(
     ):
         if FlashAttentionUtils.is_installed:
             logger.debug(
-                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
+                "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
                 "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
                 "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
...
@@ -429,23 +504,21 @@ def get_attention_backend(
                 head_dim_v,
                 ".".join([str(i) for i in device_compute_capability]),
             )
-        use_flash_attention = False
-        qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
-        if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
-            logger.debug(
-                "Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
-                qkv_layout,
-            )
-            use_fused_attention = False
+        use_flash_attention_2 = False
+    if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
+        if FlashAttentionUtils.v3_is_installed:
+            logger.debug("Disabling FlashAttention 3 for head_dim > 128")
+        use_flash_attention_3 = False
     # Filter: QKV layout
     qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
     if qkv_format == "thd":
         if use_unfused_attention:
             logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
             use_unfused_attention = False
-        if use_flash_attention and pad_between_seqs:
-            if FlashAttentionUtils.is_installed:
+        if pad_between_seqs:
+            if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
+                use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
+            ):
                 logger.debug(
                     "Disabling FlashAttention for qkv_format = thd when there is "
                     "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
...
@@ -459,9 +532,9 @@ def get_attention_backend(
             use_fused_attention = False

     # Filter: Dropout
-    if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3:
+    if attention_dropout != 0.0 and use_flash_attention_3:
         logger.debug("Disabling FlashAttention 3 for dropout")
-        FlashAttentionUtils.use_v3 = False
+        use_flash_attention_3 = False
     # Filter: Context parallelism
     # qkv_format | attn_mask_type | attn_bias_type | supported backends
...
@@ -480,42 +553,38 @@ def get_attention_backend(
             "Disabling UnfusedDotProductAttention as it does not support context parallelism"
         )
         use_unfused_attention = False
-    if context_parallel and use_flash_attention:
-        if fp8 and fp8_meta["recipe"].fp8_dpa:
-            if FlashAttentionUtils.is_installed:
+    if context_parallel and (use_flash_attention_2 or use_flash_attention_3):
+        if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed:
+            if fp8 and fp8_meta["recipe"].fp8_dpa:
                 logger.debug(
                     "Disabling FlashAttention as it does not support context parallelism with FP8"
                 )
                 use_flash_attention = False
             if "bottom_right" in attn_mask_type:
-                if FlashAttentionUtils.is_installed:
                 logger.debug(
                     "Disabling FlashAttention as it does not support context parallelism with"
                     " causal_bottom_right masking"
                 )
                 use_flash_attention = False
             elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
-                if FlashAttentionUtils.is_installed:
                 logger.debug(
                     "Disabling FlashAttention as it does not support context parallelism with"
                     " causal masking for cross-attention"
                 )
                 use_flash_attention = False
             elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
-                if FlashAttentionUtils.is_installed:
                 logger.debug(
                     "Disabling FlashAttention as it does not support context parallelism with bias"
                     " type of %s",
                     core_attention_bias_type,
                 )
                 use_flash_attention = False
             elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
-                if FlashAttentionUtils.is_installed:
                 logger.debug(
                     "Disabling FlashAttention as it does not support context parallelism with"
                     " attention bias for THD format"
                 )
                 use_flash_attention = False

     if context_parallel and use_fused_attention:
         if "bottom_right" in attn_mask_type:
...
@@ -568,61 +637,25 @@ def get_attention_backend(
     # arbitrary    | One tensor in shape broadcastable to | UnfusedDotProductAttention
     #              | [b, h, sq, skv]                      |
     if attn_mask_type == "arbitrary":
-        if use_flash_attention and FlashAttentionUtils.is_installed:
+        if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
+            use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
+        ):
             logger.debug("Disabling FlashAttention for arbitrary mask")
         use_flash_attention = False
         if use_fused_attention:
             logger.debug("Disabling FusedAttention for arbitrary mask")
             use_fused_attention = False
     if (
-        use_flash_attention
-        and FlashAttentionUtils.use_v3
+        (use_flash_attention_2 or use_flash_attention_3)
         and attn_mask_type in ["causal", "padding_causal"]
         and max_seqlen_q != max_seqlen_kv
     ):
         logger.warning(
-            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
-            "causal mask since flash-attn 2.1. See "
+            "Disabling FlashAttention as it only supports bottom-right-diagonal "
+            "causal mask since flash-attn 2.1 (our minimum supported version). See "
             "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
         )
-        FlashAttentionUtils.use_v3 = False
+        use_flash_attention = False
-    if (
-        use_flash_attention
-        and attn_mask_type in ["causal", "padding_causal"]
-        and max_seqlen_q != max_seqlen_kv
-    ):
-        if FlashAttentionUtils.v2_1_plus:
-            logger.warning(
-                "Disabling FlashAttention as it only supports bottom-right-diagonal "
-                "causal mask since flash-attn 2.1. See "
-                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
-            )
-            use_flash_attention = False
-        if not FlashAttentionUtils.is_installed:
-            FlashAttentionUtils.max_version = PkgVersion("2.1")
-    if (
-        use_flash_attention
-        and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
-        and max_seqlen_q != max_seqlen_kv
-    ):
-        if not FlashAttentionUtils.is_installed:
-            FlashAttentionUtils.version_required = PkgVersion("2.1")
-        elif not FlashAttentionUtils.v2_1_plus and not FlashAttentionUtils.use_v3:
-            logger.warning(
-                "Disabling FlashAttention as it only supports top-left-diagonal "
-                "causal mask before flash-attn 2.1. See "
-                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
-            )
-            use_flash_attention = False
-    if (
-        use_flash_attention
-        and FlashAttentionUtils.use_v3
-        and fp8
-        and fp8_meta["recipe"].fp8_dpa
-        and "padding" in attn_mask_type
-    ):
-        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
-        FlashAttentionUtils.use_v3 = False
     # Filter: Sliding window attention
     #    backend   |    window_size    | diagonal alignment
...
@@ -653,19 +686,14 @@ def get_attention_backend(
                 "with s_q > s_kv for cross-attention"
             )
             use_fused_attention = False
-    if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
-        if FlashAttentionUtils.use_v3:
-            logger.debug(
-                "Disabling FlashAttention 3 as it does not support sliding window attention"
-            )
-            FlashAttentionUtils.use_v3 = False
+    if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
         if not FlashAttentionUtils.is_installed:
             FlashAttentionUtils.version_required = PkgVersion("2.3")
         elif not FlashAttentionUtils.v2_3_plus:
             logger.debug(
                 "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
             )
-            use_flash_attention = False
+            use_flash_attention_2 = False
     # Filter: Attention bias
     # backend                    | bias types                   | ALiBi diagonal alignment
...
@@ -676,21 +704,25 @@ def get_attention_backend(
     #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
     # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
     #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
-    if use_flash_attention and core_attention_bias_type == "alibi":
-        if FlashAttentionUtils.use_v3:
-            logger.debug("Disabling FlashAttention 3 for ALiBi")
-            FlashAttentionUtils.use_v3 = False
-        if not FlashAttentionUtils.is_installed:
-            FlashAttentionUtils.version_required = PkgVersion("2.4")
-        elif not FlashAttentionUtils.v2_4_plus:
-            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
-            use_flash_attention = False
+    if core_attention_bias_type == "alibi":
+        if use_flash_attention_3:
+            if FlashAttentionUtils.v3_is_installed:
+                logger.debug("Disabling FlashAttention 3 for ALiBi")
+            use_flash_attention_3 = False
+        if use_flash_attention_2:
+            if not FlashAttentionUtils.is_installed:
+                FlashAttentionUtils.version_required = PkgVersion("2.4")
+            elif not FlashAttentionUtils.v2_4_plus:
+                logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
+                use_flash_attention_2 = False

-    if use_flash_attention and (
+    if (
         core_attention_bias_type not in ["no_bias", "alibi"]
         or core_attention_bias_shape is not None
     ):
-        if FlashAttentionUtils.is_installed:
+        if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
+            use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
+        ):
             logger.debug("Disabling FlashAttention for pre/post_scale_bias")
         use_flash_attention = False
...
@@ -795,16 +827,16 @@ def get_attention_backend(
     #                            | otherwise: no
     # sub-backend 2              | no
     # UnfusedDotProductAttention | yes
-    if use_flash_attention and deterministic:
+    if use_flash_attention_2 and deterministic:
         if not FlashAttentionUtils.is_installed:
             FlashAttentionUtils.version_required = PkgVersion("2.4.1")
-        elif not FlashAttentionUtils.v2_4_1_plus and not FlashAttentionUtils.use_v3:
+        elif not FlashAttentionUtils.v2_4_1_plus:
             logger.warning(
                 "Disabling FlashAttention as version <2.4.1 does not support deterministic "
                 "execution. To use FlashAttention with deterministic behavior, "
                 "please install flash-attn >= 2.4.1."
             )
-            use_flash_attention = False
+            use_flash_attention_2 = False
     if use_fused_attention and deterministic:
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
             logger.debug("Disabling FusedAttention for determinism reasons")
...
@@ -821,29 +853,58 @@ def get_attention_backend(
             logger.debug("Disabling FusedAttention for determinism reasons")
             use_fused_attention = False
-    # All available backends
-    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
+    # use_flash_attention may have been set above
+    use_flash_attention_2 = use_flash_attention and use_flash_attention_2
+    use_flash_attention_3 = use_flash_attention and use_flash_attention_3

     # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
     # When `FusedAttention` does not support the provided attention params, and `FlashAttention`
     # does, we recommend users to install flash-attn if not installed already.
-    if not use_fused_attention and use_flash_attention and not FlashAttentionUtils.is_installed:
-        logger.warning(
-            "flash-attn may provide important feature support or performance improvement."
-            " Please install flash-attn %s.",
-            _get_supported_versions(
-                FlashAttentionUtils.version_required,
-                FlashAttentionUtils.max_version,
-            ),
-        )
-    if use_flash_attention and not FlashAttentionUtils.is_installed:
-        use_flash_attention = False
-        available_backends[0] = False
+    if not use_fused_attention and _NVTE_FLASH_ATTN:
+        if (
+            use_flash_attention_3
+            and not FlashAttentionUtils.v3_is_installed
+            and not FlashAttentionUtils.v3_warning_printed
+            and torch.cuda.current_device() == 0
+        ):
+            logger.warning(
+                "flash-attn v3 may provide important feature support or performance improvement."
+                " Please install flash-attn v3 by\n%s",
+                FlashAttentionUtils.v3_installation_steps,
+            )
+            FlashAttentionUtils.v3_warning_printed = True
+        elif (
+            use_flash_attention_2
+            and not FlashAttentionUtils.is_installed
+            and not FlashAttentionUtils.warning_printed
+            and torch.cuda.current_device() == 0
+        ):
+            logger.warning(
+                "flash-attn may provide important feature support or performance improvement."
+                " Please install flash-attn %s by pip3 install flash-attn==<version>.",
+                _get_supported_versions(
+                    FlashAttentionUtils.version_required,
+                    FlashAttentionUtils.max_version,
+                ),
+            )
+            FlashAttentionUtils.warning_printed = True

+    # All available backends
+    if use_flash_attention_2 and not FlashAttentionUtils.is_installed:
+        use_flash_attention_2 = False
+    if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed:
+        use_flash_attention_3 = False
+    use_flash_attention = use_flash_attention_2 or use_flash_attention_3
+    available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
+    if use_flash_attention_2:
+        flash_attention_backend = FlashAttentionUtils.version
+    if use_flash_attention_3:
+        flash_attention_backend = FlashAttentionUtils.fa3_version

     logger.debug(
-        "Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
+        "Available backends = {FlashAttention=%s%s, FusedAttention=%s%s,"
         " UnfusedDotProductAttention=%s}",
         bool(available_backends[0]),
+        (f" ({str(flash_attention_backend)})" if flash_attention_backend is not None else ""),
         bool(available_backends[1]),
         (
             f" (sub-backend {int(fused_attention_backend)})"
...
@@ -854,27 +915,10 @@ def get_attention_backend(
     )
     # Select FusedAttention for performance
-    if (
-        use_flash_attention
-        and (not IS_HIP_EXTENSION)
-        and use_fused_attention
-        and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
-    ):
-        if device_compute_capability >= (9, 0):
-            logger.debug(
-                "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
-                "for performance reasons"
-            )
-            use_flash_attention = False
-    if (
-        use_flash_attention
-        and use_fused_attention
-        and not IS_HIP_EXTENSION
-        and fused_attention_backend == FusedAttnBackend["FP8"]
-        and FlashAttentionUtils.use_v3
-    ):
-        logger.debug(
-            "Disabling FlashAttention 3 to give FusedAttention preference "
-            "for performance reasons in FP8 execution"
-        )
-        use_flash_attention = False
+    if (
+        use_flash_attention
+        and use_fused_attention
+        and (not IS_HIP_EXTENSION)
+        and device_compute_capability >= (9, 0)
+    ):
+        logger.debug(
+            "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
+            "for performance reasons"
+        )
+        use_flash_attention = False
...
@@ -886,22 +930,16 @@ def get_attention_backend(
         use_unfused_attention = False
     selected_backend = "NoBackend"
     if use_flash_attention:
-        selected_backend = "FlashAttention"
+        selected_backend = f"FlashAttention ({str(flash_attention_backend)})"
     elif use_fused_attention:
         selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
     elif use_unfused_attention:
         selected_backend = "UnfusedDotProductAttention"
     logger.debug("Selected backend = %s", selected_backend)
     """global _attention_backends
     _attention_backends["use_flash_attention"] = use_flash_attention
     _attention_backends["use_fused_attention"] = use_fused_attention
     _attention_backends["fused_attention_backend"] = fused_attention_backend
     _attention_backends["use_unfused_attention"] = use_unfused_attention
     _attention_backends["backend_selection_requires_update"] = False"""
     return (
         use_flash_attention,
+        flash_attention_backend,
         use_fused_attention,
         fused_attention_backend,
         use_unfused_attention,
...
@@ -909,6 +947,49 @@ def get_attention_backend(
     )
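For reference, the selection logic at the end reduces to a straightforward priority chain. A self-contained sketch that mirrors the selected_backend strings built above; the inputs are illustrative.

def describe_selection(use_flash, flash_ver, use_fused, fused_sub, use_unfused):
    # Same priority order as the diff: Flash > Fused > Unfused > NoBackend.
    if use_flash:
        return f"FlashAttention ({flash_ver})"
    if use_fused:
        return f"FusedAttention (sub-backend {int(fused_sub)})"
    if use_unfused:
        return "UnfusedDotProductAttention"
    return "NoBackend"

print(describe_selection(True, "2.5.8", True, 1, True))  # FlashAttention (2.5.8)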
+@torch.no_grad()
+def get_padding_mask(
+    batch_size: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_kv: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_kv: int,
+):
+    """Convert cu_seqlens to attention_mask"""
+    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+    attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
+    attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
+    for i in range(batch_size):
+        attention_mask_q = torch.cat(
+            [
+                attention_mask_q,
+                torch.Tensor([False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i]))
+                .to(dtype=torch.bool)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .unsqueeze(0),
+            ],
+            dim=0,
+        )
+        attention_mask_kv = torch.cat(
+            [
+                attention_mask_kv,
+                torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
+                .to(dtype=torch.bool)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .unsqueeze(0),
+            ],
+            dim=0,
+        )
+    attention_mask = (
+        attention_mask_q.to(device="cuda"),
+        attention_mask_kv.to(device="cuda"),
+    )
+    return attention_mask
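A small, self-contained illustration of the cu_seqlens-to-mask conversion that get_padding_mask performs, with the CUDA transfer omitted so it runs anywhere; True marks padded positions.

import torch

cu_seqlens_q = torch.tensor([0, 2, 5])   # batch of two: sequence lengths 2 and 3
max_seqlen_q = 4
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
mask_rows = [
    torch.tensor([False] * int(n) + [True] * (max_seqlen_q - int(n))) for n in seqlens_q
]
# Shape [b, 1, 1, sq], matching the three unsqueezes in get_padding_mask above.
mask_q = torch.stack(mask_rows).view(-1, 1, 1, max_seqlen_q)
print(mask_q[0, 0, 0])  # tensor([False, False,  True,  True])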
 @torch.no_grad()
 def get_full_mask(
     max_seqlen_q: int,
...
@@ -1417,11 +1498,46 @@ class UnpackTensor(torch.autograd.Function):
         return None, None, _pack_tensor(indices, grad_output)

+def get_qkv_format(
+    qkv_layout: str = "bshd_bshd_bshd",
+    inference_params: InferenceParams = None,
+) -> str:
+    """Get qkv format.
+
+    Parameters
+    ----------
+    qkv_layout: str
+        Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
+    inference_params: InferenceParams, default = `None`
+        InferenceParams related to KV caching.
+
+    Returns
+    ----------
+    qkv_format: str, default = `sbhd`
+        Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
+    q_format: str
+        Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
+    kv_format: str
+        Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
+    """
+    splited = qkv_layout.replace("paged_kv_", "").split("_")
+    if inference_params is not None:
+        q_format = "".join([i for i in splited[0] if i.isalpha()])
+        kv_format = "".join([i for i in splited[1] if i.isalpha()])
+        qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format
+    else:
+        qkv_format = "".join([i for i in splited[0] if i.isalpha()])
+        q_format = qkv_format
+        kv_format = qkv_format
+    return qkv_format, q_format, kv_format
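A usage sketch of the format-splitting logic in get_qkv_format, replayed inline for a paged-KV layout where q and the KV cache use different formats:

layout = "paged_kv_thd_bshd_bshd"
splited = layout.replace("paged_kv_", "").split("_")
q_format = "".join(i for i in splited[0] if i.isalpha())   # "thd"
kv_format = "".join(i for i in splited[1] if i.isalpha())  # "bshd"
qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format
print(qkv_format)  # "thd_2bshd"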
 def get_qkv_layout(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     qkv_format: str = "sbhd",
+    inference_params: InferenceParams = None,
 ) -> str:
     """Get qkv layout.
...
@@ -1438,20 +1554,33 @@ def get_qkv_layout(
         the sequence length dimension, `b` batch size, `h` the number of attention heads,
         `d` head size, and `t` the total number of tokens in a batch, i.e.
         `t = sum(s_i) for i = 0...b-1`.
+    inference_params: InferenceParams, default = `None`
+        InferenceParams related to KV caching.

     Returns
     ----------
     qkv_layout: str
-        Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
-        memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
-        of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
-        `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
-        are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
-        `v = kv[:,:,:,1,:]`.
+        Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
+        `kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
+        paged KV caching is in play. A few examples of the layouts are as follows.
+        (1) `sb3hd` means `q`, `k`, `v` are created as one chunk of memory and that they are
+        interleaved in the `2`nd dimension. (2) `sbhd_sbh2d` means `q` and `kv` are created in
+        two chunks and that `q` itself is contiguous and `k`, `v` are interleaved with each other
+        in the `3`rd dimension, `k = kv[:,:,:,0,:]` and `v = kv[:,:,:,1,:]`. `q_format` and
+        `kv_format` in this case are still both `sbhd`. (3) `paged_kv_thd_bshd_bshd` means `q` is
+        created in `thd` and `k`, `v` are in `bshd`. This is likely due to the cache format in
+        paged KV caching.

         Mapping:
-        `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
-        `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
+        `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`, `paged_kv_sbhd_sbhd_sbhd`}
+        `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`, `paged_kv_bshd_bshd_bshd`}
         `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
+        `sbhd_2bshd`: {`sbhd_bshd_bshd`, `paged_kv_sbhd_bshd_bshd`}
+        `bshd_2sbhd`: {`bshd_sbhd_sbhd`, `paged_kv_bshd_sbhd_sbhd`}
+        `thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
+        `thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
     q: torch.Tensor
         Query tensor. It may be different from input `q` as we try to fit tensors to
         a supported layout.
...
@@ -1461,10 +1590,21 @@ def get_qkv_layout(
     v: torch.Tensor
         Value tensor. It may be different from input `v` as we try to fit tensors to
         a supported layout.
+    q_format: str
+        Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
+    kv_format: str
+        Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
     """
     check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
     assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

+    if "_2" in qkv_format:
+        q_format, kv_format = qkv_format.split("_2")
+        is_same_q_kv_format = False
+    else:
+        q_format = qkv_format
+        kv_format = qkv_format
+        is_same_q_kv_format = True

     def run_iteratively(q, k, v):
         # check data pointers
...
@@ -1551,7 +1691,10 @@ def get_qkv_layout(
         # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
         # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
         # check_ptrs_qk=True or check_ptrs_kv=True
-        qkv_layout = "_".join(list([qkv_format]) * 3)
+        if is_same_q_kv_format:
+            qkv_layout = "_".join(list([qkv_format]) * 3)
+        else:
+            qkv_layout = q_format + "_" + kv_format + "_" + kv_format
     else:
         qkv_layout = "not_supported"
...
@@ -1565,7 +1708,10 @@ def get_qkv_layout(
     if qkv_layout == "not_supported":
         raise RuntimeError("The provided qkv memory layout is not supported!")

-    return qkv_layout, q, k, v
+    if inference_params is not None and inference_params.is_paged:
+        qkv_layout = "paged_kv_" + qkv_layout
+
+    return qkv_layout, q, k, v, q_format, kv_format
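A short sketch of how the final layout string is assembled from the pieces above, assuming q_format and kv_format were already inferred and paged KV caching is active:

q_format, kv_format, is_paged = "thd", "bshd", True
if q_format == kv_format:
    qkv_layout = "_".join([q_format] * 3)
else:
    qkv_layout = q_format + "_" + kv_format + "_" + kv_format
if is_paged:
    qkv_layout = "paged_kv_" + qkv_layout
print(qkv_layout)  # paged_kv_thd_bshd_bshd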
 def check_set_window_size(
...
transformer_engine/pytorch/graph.py View file @ 4099aa8e
...
@@ -91,6 +91,14 @@ def _make_graphed_callables(
         sample_args = (sample_args,)
         sample_kwargs = (sample_kwargs,)

+    # Check training/inference
+    is_training = all(c.training for c in callables)
+    if not is_training and any(c.training for c in callables):
+        assert False, (
+            "make_graphed_callables only supports when modules are all in training or all in"
+            " inference mode."
+        )
+
     # Check sizes of args
     if _order is None:
         assert len(sample_args) == len(callables)
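The new mode check is all-or-nothing: graph capture refuses a mix of train- and eval-mode modules. A self-contained illustration of the same check outside TE:

import torch

def check_mode(callables):
    # Every module must be in the same train/eval mode before capture,
    # mirroring the assertion added in the hunk above.
    is_training = all(c.training for c in callables)
    if not is_training and any(c.training for c in callables):
        raise AssertionError(
            "make_graphed_callables only supports when modules are all in"
            " training or all in inference mode."
        )
    return is_training

print(check_mode([torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)]))                # True
print(check_mode([torch.nn.Linear(4, 4).eval(), torch.nn.Linear(4, 4).eval()]))  # False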
...
@@ -255,13 +263,16 @@ def _make_graphed_callables(
             outputs, _ = _tree_flatten(func(*args, **kwargs))
         for hook in hooks:
             hook.remove()
-        grad_inputs = torch.autograd.grad(
-            outputs=tuple(o for o in outputs if o.requires_grad),
-            inputs=tuple(i for i in static_input_surface if i.requires_grad),
-            grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-            only_inputs=True,
-            allow_unused=allow_unused_input,
-        )
+        if is_training:
+            grad_inputs = torch.autograd.grad(
+                outputs=tuple(o for o in outputs if o.requires_grad),
+                inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
+                only_inputs=True,
+                allow_unused=allow_unused_input,
+            )
+        else:
+            grad_inputs = None
         del outputs, grad_inputs
         # The following code is added specifically for MCore's special requirements,
         # aimed at preventing warmup from altering the control flow.
...
@@ -314,22 +325,23 @@ def _make_graphed_callables(
         static_grad_outputs = tuple(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
-        with torch.cuda.graph(bwd_graph, pool=mempool):
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in static_outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                only_inputs=True,
-                retain_graph=retain_graph_in_backward,
-                allow_unused=allow_unused_input,
-            )
+        if is_training:
+            with torch.cuda.graph(bwd_graph, pool=mempool):
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in static_outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                    only_inputs=True,
+                    retain_graph=retain_graph_in_backward,
+                    allow_unused=allow_unused_input,
+                )

         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
         # that don't require grad. I couldn't think of a one-liner for this pattern.
         static_grad_inputs = []
         grad_idx = 0
         for arg in static_input_surface:
-            if arg.requires_grad:
+            if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                 static_grad_inputs.append(grad_inputs[grad_idx])
                 grad_idx += 1
             else:
...
@@ -366,22 +378,23 @@ def _make_graphed_callables(
         static_grad_outputs = tuple(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
-        with torch.cuda.graph(bwd_graph, pool=mempool):
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in static_outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                only_inputs=True,
-                retain_graph=retain_graph_in_backward,
-                allow_unused=allow_unused_input,
-            )
+        if is_training:
+            with torch.cuda.graph(bwd_graph, pool=mempool):
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in static_outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                    only_inputs=True,
+                    retain_graph=retain_graph_in_backward,
+                    allow_unused=allow_unused_input,
+                )

         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs that
         # don't require grad. I couldn't think of a slick one-liner for this pattern.
         static_grad_inputs = []
         grad_idx = 0
         for arg in static_input_surface:
-            if arg.requires_grad:
+            if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                 static_grad_inputs.append(grad_inputs[grad_idx])
                 grad_idx += 1
             else:
...
@@ -422,7 +435,10 @@ def _make_graphed_callables(
         # Copy values from new tensors into static tensors
         for i in range(len_user_args):
-            if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+            if (
+                isinstance(static_input_surface[i], torch.Tensor)
+                and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
+            ):
                 static_input_surface[i].copy_(inputs[i])

         # Replay forward graph
...
transformer_engine/pytorch/module/layernorm_linear.py View file @ 4099aa8e
...
@@ -79,7 +79,6 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_bias: Union[torch.Tensor, None],
         weight: torch.Tensor,
         bias: torch.Tensor,
-        use_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
...
@@ -383,6 +382,17 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             weightmat,
...
@@ -411,7 +421,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
-        ctx.use_bias = use_bias
+        ctx.use_bias = bias is not None
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp_shape
...
@@ -526,8 +536,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                origin_weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                origin_weight.main_grad = main_grad

         ctx.ub_obj_gradout = None
         ub_obj_dgrad = None
...
@@ -742,10 +755,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)

-        # Don't return grad bias if not needed
-        if not ctx.use_bias:
-            grad_bias = None
-
         # Synchronize tensor parallel communication
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
...
@@ -827,7 +836,6 @@ class _LayerNormLinear(torch.autograd.Function):
             dbeta,
             wgrad,
             grad_bias,
-            None,  # use_bias
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
...
@@ -1330,8 +1338,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.layer_norm_weight,
             self.layer_norm_bias,
             weight_tensor,
-            bias_tensor,
-            self.apply_bias and not self.gemm_bias_unfused_add,
+            bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
             self.eps,
             is_first_microbatch,
             self.fp8,
...
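The recurring change in this file is the bias convention: instead of threading a separate use_bias flag through the autograd function, the caller now passes the bias tensor or None and the flag is inferred from the argument itself. A minimal sketch of that convention in plain PyTorch, not the TE module:

import torch

def forward(inp: torch.Tensor, weight: torch.Tensor, bias) -> torch.Tensor:
    use_bias = bias is not None   # replaces the old explicit use_bias argument
    out = inp @ weight.t()
    return out + bias if use_bias else out

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
print(forward(x, w, b).shape)     # torch.Size([2, 3])
print(forward(x, w, None).shape)  # bias skipped entirely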
transformer_engine/pytorch/module/layernorm_mlp.py View file @ 4099aa8e
...
@@ -140,10 +140,8 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_bias: torch.Tensor,
         fc1_weight: torch.Tensor,
         fc1_bias: torch.Tensor,
-        use_fc1_bias: bool,
         fc2_weight: torch.Tensor,
         fc2_bias: torch.Tensor,
-        use_fc2_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
...
@@ -368,7 +366,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # FC1 GEMM
-        # There are 2 fussions possible:
+        # There are 2 fusions possible:
         # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
         # - bias_gelu_fusion - only for full precision.
         # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer
...
@@ -453,8 +451,7 @@ class _LayerNormMLP(torch.autograd.Function):
         )
         if not is_grad_enabled:
             clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-        else:
+        if is_grad_enabled:
             if cpu_offloading:
                 if fp8 and fc1_weight_final is not None:
                     set_offloading_param(fc1_weight_final, "weight_offloading", True)
...
@@ -537,9 +534,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
-        ctx.use_fc1_bias = use_fc1_bias
-        ctx.use_fc2_bias = use_fc2_bias
-        ctx.use_bias = ctx.use_fc1_bias
+        ctx.use_bias = fc2_bias is not None
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp_shape
...
@@ -774,14 +769,13 @@ class _LayerNormMLP(torch.autograd.Function):
                 quantization_params=None,  # wgrad in high precision
                 layout="NT",
                 grad=True,
-                bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
+                bias=fc2_bias if fc2_bias_grad is None else None,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 use_split_accumulator=_2X_ACC_WGRAD,
                 out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
             if fc2_bias_grad is None:
                 fc2_bias_grad = fc2_bias_grad_
-            del fc2_bias_grad_
             clear_tensor_data(act_out)

         # bias computation
...
@@ -1046,11 +1040,9 @@ class _LayerNormMLP(torch.autograd.Function):
             dgamma,
             dbeta,
             fc1_wgrad,
-            fc1_bias_grad if ctx.use_fc1_bias else None,
-            None,  # use_fc1_bias
+            fc1_bias_grad if fc1_bias is not None else None,
             fc2_wgrad,  # pylint: disable=possibly-used-before-assignment
-            fc2_bias_grad if ctx.use_fc2_bias else None,
-            None,  # use_fc2_bias
+            fc2_bias_grad,
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
...
@@ -1471,10 +1463,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias,
             fc1_weight,
             fc1_bias,
-            self.use_bias,
             fc2_weight,
-            fc2_bias,
-            self.apply_bias and not self.gemm_bias_unfused_add,
+            fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
             self.eps,
             is_first_microbatch,
             self.fp8,
...
transformer_engine/pytorch/module/linear.py View file @ 4099aa8e
...
@@ -291,6 +291,17 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

+        if cpu_offloading:
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.grad_added_to_main_grad:
+                # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                # You need to preserve the weight object to have all the attributes user
+                # sets for the weights. Because of this, it is not recommended to offload
+                # weights if weights are externally touched outside this module
+                ctx.weight_object = weight
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
...
@@ -392,9 +403,11 @@ class _Linear(torch.autograd.Function):
             else None
         )

-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight = torch.nn.Parameter(weight, weight.requires_grad)
-            weight.main_grad = main_grad
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                weight = ctx.weight_object
+            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+                weight.main_grad = main_grad

         # Gather intermediate/activation tensors if needed
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
...
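The comment block repeated in these hunks is worth illustrating: tensor-level hooks and ops can silently strip the nn.Parameter wrapper along with any custom attributes, which is why the original weight object is stashed on ctx. A self-contained demonstration, with detach() standing in for the hook round-trip:

import torch

weight = torch.nn.Parameter(torch.randn(2, 2))
weight.grad_added_to_main_grad = True        # custom attribute set by the user

as_tensor = weight.detach()                  # Parameter wrapper is gone here
print(isinstance(as_tensor, torch.nn.Parameter))      # False
print(hasattr(as_tensor, "grad_added_to_main_grad"))  # False

saved = weight                               # analogous to ctx.weight_object
print(hasattr(saved, "grad_added_to_main_grad"))      # True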