Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
720 additions
and
1316 deletions
+720
-1316
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+8
-9
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+4
-9
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+3
-6
vllm/attention/layer.py
vllm/attention/layer.py
+45
-9
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
...ops/blocksparse_attention/blocksparse_attention_kernel.py
+0
-433
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+0
-239
vllm/attention/ops/blocksparse_attention/utils.py
vllm/attention/ops/blocksparse_attention/utils.py
+0
-246
vllm/attention/ops/hpu_paged_attn.py
vllm/attention/ops/hpu_paged_attn.py
+0
-88
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+0
-195
vllm/attention/ops/rocm_aiter_mla.py
vllm/attention/ops/rocm_aiter_mla.py
+6
-2
vllm/attention/ops/triton_unified_attention.py
vllm/attention/ops/triton_unified_attention.py
+14
-3
vllm/attention/selector.py
vllm/attention/selector.py
+40
-18
vllm/attention/utils/__init__.py
vllm/attention/utils/__init__.py
+0
-0
vllm/attention/utils/kv_sharing_utils.py
vllm/attention/utils/kv_sharing_utils.py
+33
-0
vllm/benchmarks/datasets.py
vllm/benchmarks/datasets.py
+91
-1
vllm/benchmarks/serve.py
vllm/benchmarks/serve.py
+40
-17
vllm/collect_env.py
vllm/collect_env.py
+25
-20
vllm/compilation/backends.py
vllm/compilation/backends.py
+38
-16
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+364
-4
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+9
-1
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
vllm/attention/backends/rocm_flash_attn.py
View file @
711aa9d5
...
...
@@ -4,7 +4,7 @@
import
itertools
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
triton
...
...
@@ -21,7 +21,9 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata
)
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
from
vllm.utils
import
SUPPORT_TC
,
gpuname
...
...
@@ -502,21 +504,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"ROCM_FLASH backend."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in ROCm Flash Attention is not supported yet, it "
"will fail back to global attention for long context."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support blocksparse attention."
)
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
...
...
@@ -616,10 +615,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim
))
def
fused_output_quant_supported
(
self
,
dtype
:
torch
.
dtype
,
static
:
bool
,
group_shape
:
tuple
[
int
,
int
]
):
group_shape
:
GroupShape
):
if
self
.
use_triton_flash_attn
:
return
dtype
==
current_platform
.
fp8_dtype
(
)
and
static
and
group_shape
==
(
-
1
,
-
1
)
# per-tensor
)
and
static
and
group_shape
==
GroupShape
.
PER_TENSOR
# Only supported in the Triton backend
return
False
...
...
vllm/attention/backends/triton_mla.py
View file @
711aa9d5
...
...
@@ -2,8 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
from
typing
import
List
,
Optional
,
Type
,
Any
,
Dict
from
.triton_config
import
get_nearest_config
,
get_attention_mla_configs
,
get_config
,
get_attention_mla_configs_json
import
torch
...
...
@@ -42,7 +41,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
...
...
@@ -50,17 +48,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
"alibi_slopes, sliding_window, logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
vllm/attention/backends/xformers.py
View file @
711aa9d5
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
xformers
import
ops
as
xops
...
...
@@ -394,17 +394,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"XFORMERS backend."
)
if
logits_soft_cap
is
not
None
:
logger
.
warning_once
(
"XFormers does not support logits soft cap. "
"Outputs may be slightly off."
)
...
...
vllm/attention/layer.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -10,18 +10,47 @@ import torch.nn.functional as F
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
logger
=
init_logger
(
__name__
)
USE_XFORMERS_OPS
=
None
def
check_xformers_availability
():
global
USE_XFORMERS_OPS
if
USE_XFORMERS_OPS
is
not
None
:
return
USE_XFORMERS_OPS
if
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
100
):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS
=
False
else
:
try
:
from
importlib.util
import
find_spec
find_spec
(
"xformers.ops"
)
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
# the warning only needs to be shown once
if
not
USE_XFORMERS_OPS
:
logger
.
warning
(
"Xformers is not available, falling back."
)
return
USE_XFORMERS_OPS
class
Attention
(
nn
.
Module
):
...
...
@@ -45,7 +74,6 @@ class Attention(nn.Module):
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
per_layer_sliding_window
:
Optional
[
int
]
=
None
,
use_mla
:
bool
=
False
,
...
...
@@ -109,6 +137,15 @@ class Attention(nn.Module):
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
# For v1 we have backend agnostic iRoPE (local chunked attention)
# we have to store the flag on the layer so gpu model runner can
# set KVSpec appropriately (and pop it so it doesnt get passed to
# the backends)
if
envs
.
VLLM_USE_V1
:
self
.
use_irope
=
extra_impl_args
.
pop
(
"use_irope"
,
False
)
else
:
self
.
use_irope
=
extra_impl_args
.
get
(
"use_irope"
,
False
)
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
quant_config
else
None
if
quant_method
is
not
None
and
not
isinstance
(
...
...
@@ -134,12 +171,11 @@ class Attention(nn.Module):
kv_cache_dtype
,
block_size
,
is_attention_free
,
blocksparse_params
is
not
None
,
use_mla
=
use_mla
)
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
extra_impl_args
)
self
.
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
self
.
dtype
=
dtype
...
...
@@ -160,10 +196,6 @@ class Attention(nn.Module):
self
.
attn_type
=
attn_type
if
kv_sharing_target_layer_name
is
not
None
:
if
not
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Cross-layer KV sharing is not supported in V0."
)
validate_kv_sharing_target
(
prefix
,
kv_sharing_target_layer_name
,
...
...
@@ -318,6 +350,10 @@ class MultiHeadAttention(nn.Module):
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
}
else
_Backend
.
TORCH_SDPA
if
(
self
.
attn_backend
==
_Backend
.
XFORMERS
and
not
check_xformers_availability
()):
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
def
forward
(
self
,
query
:
torch
.
Tensor
,
...
...
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
def
blocksparse_flash_attn_varlen_fwd
(
q
,
k
,
v
,
# (#tokens, n_heads, head_size)
cu_seqlens_k
,
cu_seqlens_q
,
sm_scale
,
sparse_layout
,
*
,
block_size
=
64
,
q_block_size
=
None
,
max_seqlen
=
None
):
# split q to blocks
assert
isinstance
(
sparse_layout
,
(
list
,
tuple
))
_
,
n_heads
,
head_size
=
q
.
shape
batch_size
=
cu_seqlens_k
.
size
(
0
)
-
1
q_block_size
=
q_block_size
or
block_size
assert
q
.
dim
()
==
k
.
dim
()
==
v
.
dim
()
==
3
assert
q
.
size
(
1
)
%
k
.
size
(
1
)
==
0
assert
q
.
size
(
2
)
==
k
.
size
(
2
)
# TODO(linxihui): allow k, v to have different head_size
assert
k
.
shape
==
v
.
shape
assert
cu_seqlens_k
.
dim
()
==
1
q_k_ratio
=
q
.
size
(
1
)
//
k
.
size
(
1
)
if
cu_seqlens_q
is
None
:
if
q
.
size
(
0
)
==
batch_size
:
# decoding only
cu_seqlens_q
=
torch
.
arange
(
0
,
batch_size
+
1
,
dtype
=
cu_seqlens_k
.
dtype
,
device
=
cu_seqlens_k
.
device
,
)
elif
q
.
size
(
0
)
==
k
.
size
(
0
):
cu_seqlens_q
=
cu_seqlens_k
else
:
raise
ValueError
(
"cu_seqlens_q must be specified
\
if it mix of prefilling and decoding."
)
else
:
assert
cu_seqlens_k
.
size
(
0
)
==
cu_seqlens_q
.
size
(
0
)
# switch to use cpu to avoid too many kernel launches when iterated over
q_lens
=
(
cu_seqlens_q
[
1
:]
-
cu_seqlens_q
[:
-
1
]).
cpu
()
k_lens
=
(
cu_seqlens_k
[
1
:]
-
cu_seqlens_k
[:
-
1
]).
cpu
()
assert
torch
.
logical_or
(
q_lens
==
1
,
k_lens
==
q_lens
).
all
(),
(
"length of q should either be 1 (decoding) or same as k (prefilling)."
)
if
max_seqlen
:
assert
k_lens
.
max
()
<=
max_seqlen
n_blocks
=
(
q_lens
+
q_block_size
-
1
)
//
q_block_size
q_batch_ids
=
torch
.
tensor
(
[
i
for
i
,
n
in
enumerate
(
n_blocks
)
for
_
in
range
(
n
)],
dtype
=
cu_seqlens_q
.
dtype
,
device
=
cu_seqlens_q
.
device
,
)
q_start_sids
=
torch
.
tensor
(
[
i
*
q_block_size
for
n
in
n_blocks
for
i
in
range
(
n
)],
dtype
=
cu_seqlens_q
.
dtype
,
device
=
cu_seqlens_q
.
device
,
)
out
=
q
.
new_empty
(
q
.
shape
)
cu_seqlens_q
=
cu_seqlens_q
.
contiguous
()
cu_seqlens_k
=
cu_seqlens_k
.
contiguous
()
layout_crow_indices
,
layout_col_indices
=
sparse_layout
block_d
=
triton
.
next_power_of_2
(
head_size
)
decoding_only
=
(
q_lens
==
1
).
all
().
item
()
grid
=
(
len
(
q_start_sids
),
n_heads
,
1
)
_fwd_kernel_batch_inference
[
grid
](
q
,
k
,
v
,
out
,
sm_scale
,
cu_seqlens_q
[:
-
1
],
cu_seqlens_q
[
1
:],
cu_seqlens_k
[:
-
1
],
cu_seqlens_k
[
1
:],
q_batch_ids
,
q_start_sids
,
0
,
*
q
.
stride
(),
0
,
*
k
.
stride
(),
0
,
*
v
.
stride
(),
0
,
*
out
.
stride
(),
layout_crow_indices
,
layout_col_indices
,
*
layout_crow_indices
.
stride
(),
*
layout_col_indices
.
stride
(),
q_k_ratio
,
HAS_BATCH_DIM
=
False
,
D_HEAD
=
head_size
,
BLOCK_M
=
q_block_size
,
BLOCK_N
=
block_size
,
BLOCK_D
=
block_d
,
BLOCK_M_LOADING
=
(
16
if
decoding_only
else
q_block_size
),
# smaller for decoding
EVEN_D
=
block_d
==
head_size
,
num_warps
=
1
if
decoding_only
else
4
,
num_stages
=
3
)
return
out
@
triton
.
jit
def
_fwd_kernel_inner
(
acc
,
l_i
,
m_i
,
q
,
Q
,
k_block_col_idx
,
layout_col_ptr
,
layout_col_stride_h
,
layout_col_stride_m
,
k_ptrs
,
v_ptrs
,
off_h
,
offs_m
,
offs_n
,
offs_d
,
stride_kt
,
stride_vt
,
sm_scale
,
k_seqlen
,
past_len
,
LAST_K_BLOCK
:
tl
.
constexpr
,
BLOCK_M_LOADING
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
D_HEAD
:
tl
.
constexpr
,
EVEN_D
:
tl
.
constexpr
,
M_LT_N
:
tl
.
constexpr
,
):
k_block_id
=
tl
.
load
(
layout_col_ptr
+
off_h
*
layout_col_stride_h
+
k_block_col_idx
*
layout_col_stride_m
).
to
(
tl
.
int32
)
start_n
=
k_block_id
*
BLOCK_N
if
LAST_K_BLOCK
:
if
EVEN_D
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
,
mask
=
offs_n
[
None
,
:]
+
start_n
<
k_seqlen
,
other
=
0.0
,
)
else
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
,
mask
=
(
offs_n
[
None
,
:]
+
start_n
<
k_seqlen
)
&
(
offs_d
[:,
None
]
<
D_HEAD
),
other
=
0.0
,
)
else
:
if
EVEN_D
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
)
else
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
,
mask
=
offs_d
[:,
None
]
<
D_HEAD
,
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M_LOADING
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
*=
sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if
LAST_K_BLOCK
|
M_LT_N
:
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
"-inf"
),
)
# flash-attn2
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
p
=
tl
.
math
.
exp2
(
qk
-
m_ij
[:,
None
])
l_ij
=
tl
.
sum
(
p
,
1
)
alpha
=
tl
.
math
.
exp2
(
m_i
-
m_ij
)
acc
=
acc
*
alpha
[:,
None
]
# update m_i
m_i
=
m_ij
l_i
=
l_i
*
alpha
+
l_ij
p
=
p
.
to
(
Q
.
dtype
.
element_ty
)
# update acc
if
LAST_K_BLOCK
:
if
EVEN_D
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
,
mask
=
offs_n
[:,
None
]
+
start_n
<
k_seqlen
,
other
=
0.0
,
)
else
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
,
mask
=
(
offs_n
[:,
None
]
+
start_n
<
k_seqlen
)
&
(
offs_d
[
None
,
:]
<
D_HEAD
),
other
=
0.0
,
)
else
:
if
EVEN_D
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
)
else
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
,
mask
=
offs_d
[
None
,
:]
<
D_HEAD
,
other
=
0.0
)
acc
+=
tl
.
dot
(
p
,
v
)
return
acc
,
l_i
,
m_i
@
triton
.
heuristics
({
"M_LT_N"
:
lambda
kwargs
:
kwargs
[
"BLOCK_M"
]
<
kwargs
[
"BLOCK_N"
],
})
@
triton
.
jit
def
_fwd_kernel_batch_inference
(
Q
,
K
,
V
,
Out
,
sm_scale
,
q_batch_starts
,
q_batch_ends
,
k_batch_starts
,
k_batch_ends
,
q_batch_ids
,
q_start_sids
,
stride_qb
,
stride_qt
,
stride_qh
,
stride_qd
,
stride_kb
,
stride_kt
,
stride_kh
,
stride_kd
,
stride_vb
,
stride_vt
,
stride_vh
,
stride_vd
,
stride_ob
,
stride_ot
,
stride_oh
,
stride_od
,
layout_crow_ptr
,
layout_col_ptr
,
layout_crow_stride_h
,
layout_crow_stride_m
,
layout_col_stride_h
,
layout_col_stride_m
,
q_k_ratio
,
HAS_BATCH_DIM
:
tl
.
constexpr
,
D_HEAD
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
BLOCK_M_LOADING
:
tl
.
constexpr
,
EVEN_D
:
tl
.
constexpr
,
M_LT_N
:
tl
.
constexpr
,
):
"""
NOTATION:
pid: position id
sid: storage id
sbid: storage block id
pbid: position block id
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
TODO(linxihui):
Optimize grouped-attn
"""
off_zm
=
tl
.
program_id
(
0
)
off_h
=
tl
.
program_id
(
1
)
off_h_for_kv
=
off_h
//
q_k_ratio
if
HAS_BATCH_DIM
:
off_z
=
tl
.
program_id
(
2
)
Q
+=
off_z
*
stride_qb
K
+=
off_z
*
stride_kb
V
+=
off_z
*
stride_vb
Out
+=
off_z
*
stride_ob
start_m
=
off_zm
q_start_sid
=
start_m
*
BLOCK_M
# always 0 for decoding
else
:
off_z
=
tl
.
load
(
q_batch_ids
+
off_zm
).
to
(
tl
.
int32
)
# [0, 0, 0, 1]
q_start_sid
=
tl
.
load
(
q_start_sids
+
off_zm
)
start_m
=
q_start_sid
//
BLOCK_M
# q_sbid
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M_LOADING
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
q_cu_start
=
tl
.
load
(
q_batch_starts
+
off_z
).
to
(
tl
.
int32
)
q_seqlen
=
tl
.
load
(
q_batch_ends
+
off_z
).
to
(
tl
.
int32
)
-
q_cu_start
k_cu_start
=
tl
.
load
(
k_batch_starts
+
off_z
).
to
(
tl
.
int32
)
k_seqlen
=
tl
.
load
(
k_batch_ends
+
off_z
).
to
(
tl
.
int32
)
-
k_cu_start
past_len
=
k_seqlen
-
q_seqlen
Q
+=
q_cu_start
*
stride_qt
+
off_h
*
stride_qh
K
+=
k_cu_start
*
stride_kt
+
off_h_for_kv
*
stride_kh
V
+=
k_cu_start
*
stride_vt
+
off_h_for_kv
*
stride_vh
Out
+=
q_cu_start
*
stride_ot
+
off_h
*
stride_oh
q_pbid
=
(
past_len
+
q_start_sid
)
//
BLOCK_M
if
EVEN_D
:
q
=
tl
.
load
(
Q
+
offs_m
[:,
None
]
*
stride_qt
+
offs_d
[
None
,
:]
*
stride_qd
,
mask
=
offs_m
[:,
None
]
<
q_seqlen
,
other
=
0.0
,
)
else
:
q
=
tl
.
load
(
Q
+
offs_m
[:,
None
]
*
stride_qt
+
offs_d
[
None
,
:]
*
stride_qd
,
mask
=
(
offs_m
[:,
None
]
<
q_seqlen
)
&
(
offs_d
[
None
,
:]
<
D_HEAD
),
other
=
0.0
,
)
sparse_crow_ptr
=
(
layout_crow_ptr
+
off_h
*
layout_crow_stride_h
+
q_pbid
*
layout_crow_stride_m
)
# TODO(linxihui): load at once, with any Triton version
# that supports `tl.split`, e.g., Triton 3.0
k_block_start
=
tl
.
load
(
sparse_crow_ptr
).
to
(
tl
.
int32
)
k_block_end
=
tl
.
load
(
sparse_crow_ptr
+
1
).
to
(
tl
.
int32
)
m_i
=
tl
.
zeros
([
BLOCK_M_LOADING
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M_LOADING
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M_LOADING
,
BLOCK_D
],
dtype
=
tl
.
float32
)
k_ptrs
=
K
+
offs_n
[
None
,
:]
*
stride_kt
+
offs_d
[:,
None
]
*
stride_kd
v_ptrs
=
V
+
offs_n
[:,
None
]
*
stride_vt
+
offs_d
[
None
,
:]
*
stride_vd
sm_scale
*=
(
1.44269504
# 1/log2 as we use base2 for exponential and logarithm
)
for
k_block_col_idx
in
range
(
k_block_start
,
k_block_end
-
1
):
acc
,
l_i
,
m_i
=
_fwd_kernel_inner
(
acc
,
l_i
,
m_i
,
q
,
Q
,
k_block_col_idx
,
layout_col_ptr
,
layout_col_stride_h
,
layout_col_stride_m
,
k_ptrs
,
v_ptrs
,
off_h
,
offs_m
,
offs_n
,
offs_d
,
stride_kt
,
stride_vt
,
sm_scale
,
k_seqlen
,
past_len
,
False
,
BLOCK_M_LOADING
,
BLOCK_N
,
D_HEAD
,
EVEN_D
,
M_LT_N
,
)
acc
,
l_i
,
m_i
=
_fwd_kernel_inner
(
acc
,
l_i
,
m_i
,
q
,
Q
,
k_block_end
-
1
,
layout_col_ptr
,
layout_col_stride_h
,
layout_col_stride_m
,
k_ptrs
,
v_ptrs
,
off_h
,
offs_m
,
offs_n
,
offs_d
,
stride_kt
,
stride_vt
,
sm_scale
,
k_seqlen
,
past_len
,
True
,
BLOCK_M_LOADING
,
BLOCK_N
,
D_HEAD
,
EVEN_D
,
M_LT_N
,
)
# flash-attn 2
m_i
+=
tl
.
math
.
log2
(
l_i
)
acc
=
acc
/
l_i
[:,
None
]
# write output
if
EVEN_D
:
tl
.
store
(
Out
+
offs_m
[:,
None
]
*
stride_ot
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
offs_m
[:,
None
]
<
q_seqlen
,
)
else
:
tl
.
store
(
Out
+
offs_m
[:,
None
]
*
stride_ot
+
offs_d
[
None
,
:]
*
stride_od
,
acc
,
mask
=
(
offs_m
[:,
None
]
<
q_seqlen
)
&
(
offs_d
[
None
,
:]
<
D_HEAD
),
)
vllm/attention/ops/blocksparse_attention/interface.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
torch
from
vllm.platforms
import
current_platform
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
IS_COMPUTE_8_OR_ABOVE
=
current_platform
.
has_device_capability
(
80
)
if
IS_COMPUTE_8_OR_ABOVE
:
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
class
LocalStridedBlockSparseAttn
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
n_heads
,
max_seqlen
,
local_blocks
,
vert_stride
,
block_size
,
device
=
None
,
dtype
=
None
,
homo_head
=
False
,
active_head_range
=
None
,
q_block_size
=
None
,
use_spda
=
None
,
):
super
().
__init__
()
if
use_spda
is
None
:
use_spda
=
current_platform
.
is_rocm
()
or
\
current_platform
.
is_cpu
()
or
not
\
IS_COMPUTE_8_OR_ABOVE
device
=
device
or
(
torch
.
cuda
.
current_device
()
if
current_platform
.
is_cuda_alike
()
else
"cpu"
)
device
=
torch
.
device
(
device
)
# NOTE: vllm CPU backend support BF16 instead of FP16.
dtype
=
dtype
or
(
torch
.
bfloat16
if
IS_COMPUTE_8_OR_ABOVE
or
device
.
type
==
"cpu"
else
torch
.
half
)
self
.
n_heads
=
n_heads
self
.
max_seqlen
=
max_seqlen
self
.
local_blocks
=
local_blocks
self
.
vert_stride
=
vert_stride
self
.
use_spda
=
use_spda
self
.
dtype
=
dtype
self
.
device
=
device
self
.
block_size
=
block_size
self
.
q_block_size
=
q_block_size
self
.
homo_head
=
homo_head
self
.
active_head_range
=
active_head_range
self
.
head_sliding_step
=
get_head_sliding_step
(
n_heads
,
vert_stride
,
homo_head
)
sparse_layout
,
sparse_pattern
,
self
.
dense_attn_mask
=
(
self
.
get_attn_pattern
(
dtype
,
device
))
if
q_block_size
is
not
None
and
q_block_size
!=
block_size
:
if
q_block_size
>
block_size
:
assert
q_block_size
%
block_size
==
0
blocks_to_merge
=
q_block_size
//
block_size
shape
=
sparse_pattern
.
shape
sparse_pattern
=
sparse_pattern
.
view
(
shape
[
0
],
-
1
,
blocks_to_merge
,
shape
[
-
1
])
sparse_pattern
=
sparse_pattern
.
sum
(
2
)
sparse_layout
=
dense_to_crow_col
(
sparse_pattern
)
else
:
raise
ValueError
(
"Does not support smaller q_block_size. It will be slower."
)
self
.
sparse_layout
=
sparse_layout
def
get_attn_pattern
(
self
,
dtype
,
device
):
sparse_layout
,
sparse_pattern
,
dense_attn_mask
=
get_sparse_attn_mask
(
self
.
n_heads
,
self
.
max_seqlen
,
self
.
max_seqlen
,
dtype
,
device
,
block_size
=
self
.
block_size
,
local_blocks
=
self
.
local_blocks
,
vert_stride
=
self
.
vert_stride
,
homo_head
=
self
.
homo_head
,
return_dense
=
self
.
use_spda
,
dense_mask_type
=
"bias"
,
)
if
(
not
self
.
homo_head
)
and
(
self
.
active_head_range
is
not
None
):
assert
isinstance
(
self
.
active_head_range
,
tuple
)
assert
(
len
(
self
.
active_head_range
)
==
2
)
h_start
,
h_end
=
self
.
active_head_range
sparse_layout
=
tuple
(
x
[
h_start
:
h_end
]
for
x
in
sparse_layout
)
if
self
.
use_spda
:
dense_attn_mask
=
dense_attn_mask
[
h_start
:
h_end
]
return
sparse_layout
,
sparse_pattern
,
dense_attn_mask
def
varlen_attn
(
self
,
q
,
k
,
v
,
cu_seqlens_k
,
cu_seqlens_q
=
None
,
sm_scale
=
None
):
"""
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
cu_seqlens_k: shape=(batch_size + 1,),
indicating segment of samples,
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify is when q is a mix of
prefilling and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
return: tensor of shape as q.
"""
assert
(
IS_COMPUTE_8_OR_ABOVE
),
"Requires compute capability of 8 or above (Ampere or newer) to use
\
Triton kernel."
sm_scale
=
sm_scale
or
1.0
/
math
.
sqrt
(
q
.
size
(
-
1
))
return
blocksparse_flash_attn_varlen_fwd
(
q
,
k
,
v
,
cu_seqlens_k
,
cu_seqlens_q
,
sm_scale
,
self
.
sparse_layout
,
block_size
=
self
.
block_size
,
q_block_size
=
self
.
q_block_size
,
max_seqlen
=
self
.
max_seqlen
,
)
@
staticmethod
def
transpose_and_pad
(
x
,
cu_seqlens
,
maxlen
,
head_repeats
=
1
):
"""
:param x: (total_tokens, n_heads, head_size)
:return: (batch, n_heads, length, head_size)
"""
x_padded
=
x
.
new_empty
(
len
(
cu_seqlens
)
-
1
,
x
.
size
(
1
),
head_repeats
,
maxlen
,
x
.
size
(
2
))
cu_seqlens
=
cu_seqlens
.
cpu
()
for
i
,
(
s
,
e
)
in
enumerate
(
zip
(
cu_seqlens
[:
-
1
],
cu_seqlens
[
1
:])):
x_padded
[
i
,
:,
:,
:
e
-
s
].
copy_
(
x
[
s
:
e
].
transpose
(
0
,
1
).
unsqueeze
(
1
))
return
x_padded
.
flatten
(
1
,
2
)
@
staticmethod
def
transpose_and_unpad
(
x_padded
,
cu_seqlens
):
"""
:param x_padded: (batch, n_heads, length, head_size)
:return: (total_tokens, n_heads, head_size)
"""
cu_seqlens
=
cu_seqlens
.
cpu
()
total_n_tokens
=
cu_seqlens
[
-
1
]
x
=
x_padded
.
new_empty
(
total_n_tokens
,
x_padded
.
size
(
1
),
x_padded
.
size
(
3
))
for
i
,
(
s
,
e
)
in
enumerate
(
zip
(
cu_seqlens
[:
-
1
],
cu_seqlens
[
1
:])):
x
[
s
:
e
].
copy_
(
x_padded
[
i
,
:,
:
e
-
s
].
transpose
(
0
,
1
))
return
x
def
spda
(
self
,
q
,
k
,
v
,
cu_seqlens_k
,
cu_seqlens_q
=
None
,
sm_scale
=
None
):
"""For CPU, V100 or other older GPUs.
NOTE: torch SPDA supports nested tensor,
but seems extremely slow. Choose to pad instead.
"""
assert
(
cu_seqlens_q
is
None
or
(
cu_seqlens_q
==
cu_seqlens_k
).
all
()),
"Can only handle prompt with SPDA."
assert
q
.
size
(
0
)
==
k
.
size
(
0
),
"can only handle prompt with SPDA."
assert
q
.
size
(
1
)
%
k
.
size
(
1
)
==
0
q_k_ratio
=
q
.
size
(
1
)
//
k
.
size
(
1
)
sm_scale
=
sm_scale
or
1.0
/
math
.
sqrt
(
q
.
size
(
-
1
))
cu_seqlens
=
cu_seqlens_k
.
cpu
()
maxlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
if
(
self
.
dense_attn_mask
.
dtype
!=
q
.
dtype
or
self
.
dense_attn_mask
.
device
!=
q
.
device
):
_
,
_
,
self
.
dense_attn_mask
=
self
.
get_attn_pattern
(
q
.
dtype
,
q
.
device
)
attn_mask
=
self
.
dense_attn_mask
[
None
,
:,
:
maxlen
,
:
maxlen
]
q2
=
self
.
transpose_and_pad
(
q
,
cu_seqlens
,
maxlen
,
1
)
k2
,
v2
=
(
self
.
transpose_and_pad
(
x
,
cu_seqlens
,
maxlen
,
q_k_ratio
)
for
x
in
[
k
,
v
])
spda_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q2
,
k2
,
v2
,
attn_mask
=
attn_mask
,
scale
=
sm_scale
)
return
self
.
transpose_and_unpad
(
spda_output
,
cu_seqlens
)
def
forward
(
self
,
q
,
k
,
v
,
cu_seqlens_k
,
cu_seqlens_q
=
None
,
sm_scale
=
None
):
"""Dispatch to `varlen_attn` (Ampere or newer) or
`self.spda`(cpu, Volta, Turing or older)based on
the type of device used and cuda compute capability.
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify
is when q is a mix of prefilling
and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
return: tensor of shape as q.
"""
assert
k
.
dim
()
==
3
if
self
.
use_spda
:
return
self
.
spda
(
q
,
k
,
v
,
cu_seqlens_k
,
cu_seqlens_q
=
cu_seqlens_q
,
sm_scale
=
sm_scale
,
)
return
self
.
varlen_attn
(
q
,
k
,
v
,
cu_seqlens_k
,
cu_seqlens_q
=
cu_seqlens_q
,
sm_scale
=
sm_scale
)
vllm/attention/ops/blocksparse_attention/utils.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Helper functions for 3D sparse pattern
# These function are not optimized and very inefficient.
# Avoid calling them too frequent or use a cache mechanism.
from
functools
import
lru_cache
import
numpy
as
np
import
torch
from
vllm.triton_utils
import
triton
class
csr_matrix
:
"""Simple implementation of CSR matrix conversion without scipy.
This replaced scipy.sparse.csr_matrix() previously used."""
def
__init__
(
self
,
input_array
):
if
not
isinstance
(
input_array
,
np
.
ndarray
):
raise
ValueError
(
"Input must be a NumPy array"
)
self
.
shape
=
input_array
.
shape
rows
,
cols
=
self
.
shape
data
=
[]
indices
=
[]
indptr
=
[
0
]
for
i
in
range
(
rows
):
for
j
in
range
(
cols
):
if
input_array
[
i
,
j
]:
data
.
append
(
input_array
[
i
,
j
])
indices
.
append
(
j
)
indptr
.
append
(
len
(
indices
))
self
.
data
=
np
.
array
(
data
)
self
.
indices
=
np
.
array
(
indices
)
self
.
indptr
=
np
.
array
(
indptr
)
def
dense_to_crow_col
(
x
:
torch
.
Tensor
):
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
NOTE: col_indices padded -1
"""
device
=
x
.
device
pad
=
-
1
dim
=
x
.
dim
()
assert
x
.
dim
()
in
(
2
,
3
)
if
x
.
dim
()
==
2
:
x
=
x
[
None
]
x
=
[
csr_matrix
(
xi
.
bool
().
cpu
().
numpy
())
for
xi
in
x
]
crows
=
torch
.
vstack
([
torch
.
from_numpy
(
xi
.
indptr
)
for
xi
in
x
])
cols
=
[
torch
.
from_numpy
(
xi
.
indices
)
for
xi
in
x
]
max_cols
=
max
(
len
(
xi
)
for
xi
in
cols
)
cols
=
[
torch
.
cat
([
xi
,
pad
+
xi
.
new_zeros
(
max_cols
-
xi
.
shape
[
0
])])
for
xi
in
cols
]
cols
=
torch
.
vstack
(
cols
)
if
dim
==
2
:
crows
=
crows
[
0
]
cols
=
cols
[
0
]
return
crows
.
to
(
device
),
cols
.
to
(
device
)
def
crow_col_to_dense
(
crows
:
torch
.
Tensor
,
cols
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float16
):
dim
=
crows
.
dim
()
if
dim
==
1
:
crows
=
crows
[
None
]
cols
=
cols
[
None
]
device
=
crows
.
device
crows
,
cols
=
crows
.
cpu
(),
cols
.
cpu
()
# faster in cpu
shape
=
(
crows
.
shape
[
0
],
crows
.
shape
[
1
]
-
1
,
cols
.
max
()
+
1
)
x
=
torch
.
zeros
(
shape
,
dtype
=
dtype
)
for
i
in
range
(
shape
[
0
]):
for
j
in
range
(
shape
[
1
]):
x
[
i
,
j
,
cols
[
i
,
crows
[
i
,
j
]:
crows
[
i
,
j
+
1
]]]
=
1
if
dim
==
1
:
x
=
x
[
0
]
return
x
.
to
(
device
)
def
dense_to_ccol_row
(
x
:
torch
.
Tensor
):
"""Similar, but to CSC format"""
x
=
x
.
transpose
(
-
2
,
-
1
)
return
dense_to_crow_col
(
x
)
def
ccol_row_to_dense
(
ccol
:
torch
.
Tensor
,
rows
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float16
):
return
crow_col_to_dense
(
ccol
,
rows
,
dtype
).
permute
(
0
,
2
,
1
).
contiguous
()
def
_get_sparse_attn_mask_homo_head
(
q_len
:
int
,
max_seqlen
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
block_size
:
int
=
128
,
local_blocks
:
int
=
4
,
vert_stride
:
int
=
4
,
return_dense
:
bool
=
False
,
):
"""
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
otherwise, None
"""
with
torch
.
no_grad
():
num_blocks
=
triton
.
cdiv
(
max_seqlen
,
block_size
)
q_pos
=
torch
.
arange
(
num_blocks
)[:,
None
]
k_pos
=
torch
.
arange
(
num_blocks
)[
None
]
mask_vert_strided
=
(
torch
.
arange
(
num_blocks
)
+
1
)
%
vert_stride
==
0
block_mask_dense
=
(((
q_pos
>=
k_pos
)
&
((
q_pos
-
k_pos
<
local_blocks
)
|
mask_vert_strided
)).
to
(
device
).
to
(
dtype
))
num_blocks_q
=
triton
.
cdiv
(
q_len
,
block_size
)
block_mask_dense_output
=
(
dense_to_crow_col
(
block_mask_dense
[
-
num_blocks_q
:].
contiguous
()))
if
return_dense
:
mask_dense
=
torch
.
kron
(
block_mask_dense
,
block_mask_dense
.
new_ones
((
block_size
,
block_size
)),
)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
max_seqlen
,
max_seqlen
)).
type_as
(
mask_dense
)[
-
q_len
:]
mask_dense
=
mask_dense
[
-
q_len
:,
:
max_seqlen
]
*
causal_mask
return
(
block_mask_dense_output
,
block_mask_dense
,
mask_dense
,
)
else
:
return
(
block_mask_dense_output
,
block_mask_dense
,
None
,
)
def
binary_mask_to_bias
(
mask_dense
:
torch
.
Tensor
):
mask_dense
=
1
-
mask_dense
mask_dense
.
masked_fill_
(
mask_dense
.
bool
(),
-
torch
.
inf
)
return
mask_dense
def
get_head_sliding_step
(
n_heads
:
int
,
vert_stride
:
int
,
homo_head
:
bool
=
False
):
if
homo_head
:
return
0
return
max
(
1
,
int
(
vert_stride
/
n_heads
))
@
lru_cache
def
get_sparse_attn_mask
(
n_heads
:
int
,
q_len
:
int
,
max_seqlen
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
block_size
:
int
=
64
,
local_blocks
:
int
=
4
,
vert_stride
:
int
=
4
,
homo_head
:
bool
=
True
,
return_dense
:
bool
=
False
,
dense_mask_type
:
str
=
"binary"
,
):
"""
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None
"""
assert
dense_mask_type
in
(
"binary"
,
"bias"
)
if
homo_head
:
with
torch
.
no_grad
():
(
crow
,
col
),
block_mask_dense
,
mask_dense
=
(
_get_sparse_attn_mask_homo_head
(
q_len
,
max_seqlen
,
dtype
,
device
,
block_size
,
local_blocks
,
vert_stride
,
return_dense
,
))
crow
=
crow
[
None
].
expand
(
n_heads
,
crow
.
shape
[
0
])
col
=
col
[
None
].
expand
(
n_heads
,
col
.
shape
[
0
])
if
return_dense
:
mask_dense
=
mask_dense
[
None
].
expand
(
n_heads
,
*
mask_dense
.
shape
)
if
dense_mask_type
==
"bias"
:
mask_dense
=
binary_mask_to_bias
(
mask_dense
)
return
(
crow
,
col
),
block_mask_dense
,
mask_dense
with
torch
.
no_grad
():
num_blocks
=
triton
.
cdiv
(
max_seqlen
,
block_size
)
q_pos
=
torch
.
arange
(
num_blocks
)[
None
,
:,
None
]
k_pos
=
torch
.
arange
(
num_blocks
)[
None
,
None
]
head_sliding_step
=
get_head_sliding_step
(
n_heads
,
vert_stride
)
mask_vert_strided
=
[
(
torch
.
arange
(
num_blocks
)
+
h
*
head_sliding_step
+
1
)
%
vert_stride
==
0
for
h
in
range
(
n_heads
)
]
mask_vert_strided
=
torch
.
vstack
(
mask_vert_strided
).
unsqueeze
(
1
)
block_mask_dense
=
(((
q_pos
>=
k_pos
)
&
((
q_pos
-
k_pos
<
local_blocks
)
|
mask_vert_strided
)).
to
(
device
).
to
(
dtype
))
num_blocks_q
=
triton
.
cdiv
(
q_len
,
block_size
)
block_mask_dense_output
=
block_mask_dense
[:,
-
num_blocks_q
:]
if
return_dense
:
mask_dense
=
torch
.
kron
(
block_mask_dense
,
block_mask_dense
.
new_ones
((
block_size
,
block_size
)),
)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
max_seqlen
,
max_seqlen
)).
type_as
(
mask_dense
)[
-
q_len
:]
mask_dense
=
mask_dense
[...,
-
q_len
:,
:
max_seqlen
]
*
causal_mask
[
None
]
if
dense_mask_type
==
"bias"
:
mask_dense
=
binary_mask_to_bias
(
mask_dense
)
return
(
dense_to_crow_col
(
block_mask_dense_output
),
block_mask_dense
,
mask_dense
,
)
else
:
return
(
dense_to_crow_col
(
block_mask_dense_output
),
block_mask_dense
,
None
,
)
vllm/attention/ops/hpu_paged_attn.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm_hpu_extension
import
cache_ops
,
ops
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
@
dataclass
class
HPUPagedAttentionMetadata
:
"""Metadata for PagedAttention."""
block_list
:
Optional
[
torch
.
Tensor
]
block_mapping
:
Optional
[
torch
.
Tensor
]
block_usage
:
Optional
[
torch
.
Tensor
]
block_indices
:
Optional
[
torch
.
Tensor
]
block_offsets
:
Optional
[
torch
.
Tensor
]
block_groups
:
Optional
[
torch
.
Tensor
]
class
HPUPagedAttention
:
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
is_prompt
:
bool
)
->
None
:
cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
is_prompt
)
@
staticmethod
def
forward_decode
(
**
kwargs
)
->
torch
.
Tensor
:
return
ops
.
flat_pa
(
**
kwargs
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
dst_kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
src_to_dsts
:
torch
.
Tensor
,
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dsts
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dsts
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
src_to_dsts
:
torch
.
Tensor
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dsts
)
vllm/attention/ops/ipex_attn.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Tuple
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
_use_ipex
=
True
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
except
(
ImportError
,
AttributeError
):
_use_ipex
=
False
import
torch
from
vllm
import
_custom_ops
as
ops
class
_PagedAttention
:
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
32
,
64
,
80
,
96
,
112
,
128
,
192
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
int
,
...]:
return
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
16
//
kv_cache
.
element_size
()
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
tp_rank
:
int
=
0
blocksparse_local_blocks
:
int
=
0
blocksparse_vert_stride
:
int
=
0
blocksparse_block_size
:
int
=
64
blocksparse_head_sliding_step
:
int
=
0
block_size
=
value_cache
.
shape
[
3
]
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
*
args
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
class
_IPEXPagedAttention
(
_PagedAttention
):
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
().
int
())
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
block_size
=
value_cache
.
shape
[
2
]
head_mapping
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
query
.
size
(
1
)
//
num_kv_heads
).
flatten
()
ipex_modules
.
PagedAttention
.
single_query_cached_kv_attention
(
output
,
query
.
contiguous
(),
key_cache
,
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
)
PagedAttention
=
_IPEXPagedAttention
if
_use_ipex
else
_PagedAttention
vllm/attention/ops/rocm_aiter_mla.py
View file @
711aa9d5
...
...
@@ -6,7 +6,7 @@ from typing import Optional
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
def
get_aiter_mla_metadata
(
max_batch_size
:
int
,
block_size
:
int
,
...
...
@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(
if
current_platform
.
is_rocm
():
if
is_torch_equal_or_newer
(
"2.7.0"
):
tags
=
()
else
:
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
direct_register_custom_op
(
op_name
=
"rocm_aiter_mla_decode_fwd"
,
op_func
=
mla_decode_fwd_impl
,
mutates_args
=
[
"o"
],
fake_impl
=
mla_decode_fwd_fake
,
tags
=
[
torch
.
Tag
.
needs_fixed_stride_order
]
)
tags
=
tags
)
vllm/attention/ops/triton_unified_attention.py
View file @
711aa9d5
...
...
@@ -8,10 +8,9 @@
# - Thomas Parnell <tpa@zurich.ibm.com>
import
torch
import
triton
import
triton.language
as
tl
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
logger
=
init_logger
(
__name__
)
...
...
@@ -145,7 +144,19 @@ def kernel_unified_attention_2d(
mask
=
query_mask_1
,
other
=
0.0
)
num_blocks
=
cdiv_fn
(
seq_len
,
BLOCK_SIZE
)
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len
=
context_len
+
q_block_local_idx
*
BLOCK_Q
+
(
BLOCK_M
-
1
)
//
num_queries_per_kv
+
1
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len
=
tl
.
minimum
(
max_seq_prefix_len
,
seq_len
)
# calculate the number of tiles (blocks) that need to be processed to
# cover the longest sequence prefix (due to causal masking, blocks beyond
# this prefix can be skipped)
num_blocks
=
cdiv_fn
(
max_seq_prefix_len
,
BLOCK_SIZE
)
# iterate through tiles
for
j
in
range
(
0
,
num_blocks
):
...
...
vllm/attention/selector.py
View file @
711aa9d5
...
...
@@ -3,6 +3,7 @@
import
os
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
Generator
,
Optional
,
Union
...
...
@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return
forced_attn_backend
def
supports_head_size
(
@
dataclass
(
frozen
=
True
)
class
_IsSupported
:
can_import
:
bool
head_size
:
bool
dtype
:
bool
def
__bool__
(
self
)
->
bool
:
return
self
.
can_import
and
self
.
head_size
and
self
.
dtype
def
is_attn_backend_supported
(
attn_backend
:
Union
[
str
,
type
[
AttentionBackend
]],
head_size
:
int
,
)
->
bool
:
dtype
:
torch
.
dtype
,
*
,
allow_import_error
:
bool
=
True
,
)
->
_IsSupported
:
if
isinstance
(
attn_backend
,
str
):
try
:
attn_backend
=
resolve_obj_by_qualname
(
attn_backend
)
except
ImportError
:
return
False
if
not
allow_import_error
:
raise
return
_IsSupported
(
can_import
=
False
,
head_size
=
False
,
dtype
=
False
)
assert
isinstance
(
attn_backend
,
type
)
# TODO: Update the interface once V0 is removed
if
get_supported_head_sizes
:
=
getattr
(
attn_backend
,
"get_supported_head_sizes"
,
None
):
return
head_size
in
get_supported_head_sizes
()
if
validate_head_size
:
=
getattr
(
attn_backend
,
"validate_head_size"
,
None
):
is_head_size_supported
=
head_size
in
get_supported_head_sizes
()
elif
validate_head_size
:
=
getattr
(
attn_backend
,
"validate_head_size"
,
None
):
try
:
validate_head_size
(
head_size
)
return
True
is_head_size_supported
=
True
except
Exception
:
return
False
is_head_size_supported
=
False
else
:
raise
NotImplementedError
(
f
"
{
attn_backend
.
__name__
}
does not support "
"head size validation"
)
raise
NotImplementedError
(
f
"
{
attn_backend
.
__name__
}
does not support "
"head size validation"
)
if
get_supported_dtypes
:
=
getattr
(
attn_backend
,
"get_supported_dtypes"
,
None
):
is_dtype_supported
=
dtype
in
get_supported_dtypes
()
else
:
raise
NotImplementedError
(
f
"
{
attn_backend
.
__name__
}
does not support "
"dtype validation"
)
return
_IsSupported
(
can_import
=
True
,
head_size
=
is_head_size_supported
,
dtype
=
is_dtype_supported
,
)
def
get_attn_backend
(
...
...
@@ -112,7 +143,6 @@ def get_attn_backend(
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
is_attention_free
:
bool
,
is_blocksparse
:
bool
=
False
,
use_mla
:
bool
=
False
,
)
->
type
[
AttentionBackend
]:
"""Selects which attention backend to use and lazily imports it."""
...
...
@@ -126,7 +156,6 @@ def get_attn_backend(
kv_cache_dtype
=
kv_cache_dtype
,
block_size
=
block_size
,
is_attention_free
=
is_attention_free
,
is_blocksparse
=
is_blocksparse
,
use_v1
=
envs
.
VLLM_USE_V1
,
use_mla
=
use_mla
,
)
...
...
@@ -139,16 +168,9 @@ def _cached_get_attn_backend(
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
is_attention_free
:
bool
,
is_blocksparse
:
bool
=
False
,
use_v1
:
bool
=
False
,
use_mla
:
bool
=
False
,
)
->
type
[
AttentionBackend
]:
if
is_blocksparse
:
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
from
vllm.attention.backends.blocksparse_attn
import
(
BlocksparseFlashAttentionBackend
)
return
BlocksparseFlashAttentionBackend
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if
is_attention_free
:
...
...
vllm/
prompt_adapter
/__init__.py
→
vllm/
attention/utils
/__init__.py
View file @
711aa9d5
File moved
vllm/attention/utils/kv_sharing_utils.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
error_msg
=
(
f
"Specified KV sharing target layer for
{
current_layer_name
}
"
f
"is not valid: target layer
{
target_layer_name
}
"
)
if
current_layer_name
==
target_layer_name
:
raise
ValueError
(
error_msg
+
"cannot be the same as the current layer."
)
if
target_layer_name
not
in
static_forward_context
:
from
vllm.model_executor.models.utils
import
extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx
=
extract_layer_index
(
current_layer_name
)
target_layer_idx
=
extract_layer_index
(
target_layer_name
)
if
current_layer_idx
<=
target_layer_idx
:
raise
ValueError
(
error_msg
+
"must come before the current layer."
)
else
:
raise
ValueError
(
error_msg
+
"is not a valid Attention layer in the model."
)
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type
=
static_forward_context
[
target_layer_name
].
attn_type
expected
=
static_forward_context
[
current_layer_name
].
attn_type
if
target_layer_attn_type
!=
expected
:
raise
ValueError
(
error_msg
+
f
"must be the same type as the current layer (
{
expected
}
)."
)
vllm/benchmarks/datasets.py
View file @
711aa9d5
...
...
@@ -481,6 +481,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
choices
=
[
"sharegpt"
,
"burstgpt"
,
"sonnet"
,
"random"
,
"hf"
,
"custom"
],
help
=
"Name of the dataset to benchmark on."
,
)
parser
.
add_argument
(
"--no-stream"
,
action
=
"store_true"
,
help
=
"Do not load the dataset in streaming mode."
,
)
parser
.
add_argument
(
"--dataset-path"
,
type
=
str
,
...
...
@@ -649,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
elif
args
.
dataset_path
in
ASRDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
ASRDataset
args
.
hf_split
=
"train"
elif
args
.
dataset_path
in
MLPerfDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
MLPerfDataset
args
.
hf_split
=
"train"
else
:
supported_datasets
=
set
([
dataset_name
for
cls
in
HuggingFaceDataset
.
__subclasses__
()
...
...
@@ -674,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
dataset_subset
=
args
.
hf_subset
,
dataset_split
=
args
.
hf_split
,
random_seed
=
args
.
seed
,
no_stream
=
args
.
no_stream
,
).
sample
(
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
...
...
@@ -971,6 +980,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self
,
dataset_path
:
str
,
dataset_split
:
str
,
no_stream
:
bool
=
False
,
dataset_subset
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
None
:
...
...
@@ -978,6 +988,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self
.
dataset_split
=
dataset_split
self
.
dataset_subset
=
dataset_subset
self
.
load_stream
=
not
no_stream
self
.
load_data
()
def
load_data
(
self
)
->
None
:
...
...
@@ -986,7 +997,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self
.
dataset_path
,
name
=
self
.
dataset_subset
,
split
=
self
.
dataset_split
,
streaming
=
True
,
streaming
=
self
.
load_stream
,
)
self
.
data
=
self
.
data
.
shuffle
(
seed
=
self
.
random_seed
)
...
...
@@ -1439,3 +1450,82 @@ class ASRDataset(HuggingFaceDataset):
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
return
sampled_requests
# -----------------------------------------------------------------------------
# MLPerf Dataset Implementation
# -----------------------------------------------------------------------------
class
MLPerfDataset
(
HuggingFaceDataset
):
"""
MLPerf Inference Dataset.
Dataset on HF:
https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data
Each record contains:
- "system_prompt": system role instruction.
- "question": user question.
- "output": reference answer.
We combine the system prompt and question into a chat-formatted prompt
(using the tokenizer's chat template) and set the expected output length to
the tokenized length of the provided reference answer.
"""
SUPPORTED_DATASET_PATHS
=
{
"mgoin/mlperf-inference-llama2-data"
,
"mgoin/mlperf-inference-llama3.1-data"
,
}
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
**
kwargs
,
)
->
list
[
SampleRequest
]:
# Force dynamic output length based on reference completion.
dynamic_output
=
output_len
is
None
sampled_requests
:
list
[
SampleRequest
]
=
[]
for
item
in
self
.
data
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
system_prompt
=
item
[
"system_prompt"
]
question
=
item
[
"question"
]
reference_answer
=
item
[
"output"
]
# Build chat-style prompt using tokenizer template, if available.
messages
=
[
{
"role"
:
"system"
,
"content"
:
system_prompt
},
{
"role"
:
"user"
,
"content"
:
question
},
]
prompt_formatted
=
tokenizer
.
apply_chat_template
(
messages
,
add_generation_prompt
=
True
,
tokenize
=
False
)
prompt_len
=
len
(
tokenizer
(
prompt_formatted
).
input_ids
)
# Determine output length from reference answer tokens.
ref_out_len
=
len
(
tokenizer
(
reference_answer
,
add_special_tokens
=
False
).
input_ids
)
expected_output_len
=
ref_out_len
if
dynamic_output
else
output_len
# Validate sequence lengths.
if
not
is_valid_sequence
(
prompt_len
,
expected_output_len
):
continue
sampled_requests
.
append
(
SampleRequest
(
prompt
=
prompt_formatted
,
prompt_len
=
prompt_len
,
expected_output_len
=
expected_output_len
,
)
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
return
sampled_requests
vllm/benchmarks/serve.py
View file @
711aa9d5
...
...
@@ -138,31 +138,54 @@ async def get_request(
input_requests
=
list
(
input_requests
)
total_requests
=
len
(
input_requests
)
request_index
=
0
assert
total_requests
>
0
,
"No requests provided."
for
request
in
input_requests
:
# Precompute delays among requests to minimize request send laggings
request_rates
=
[]
delay_ts
=
[]
for
request_index
,
request
in
enumerate
(
input_requests
):
current_request_rate
=
_get_current_request_rate
(
ramp_up_strategy
,
ramp_up_start_rps
,
ramp_up_end_rps
,
request_index
,
total_requests
,
request_rate
)
yield
request
,
current_request_rate
request_index
+=
1
request_rates
.
append
(
current_request_rate
)
if
current_request_rate
==
float
(
"inf"
):
# If the request rate is infinity, then we don't need to wait.
continue
theta
=
1.0
/
(
current_request_rate
*
burstiness
)
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
interval
=
np
.
random
.
gamma
(
shape
=
burstiness
,
scale
=
theta
)
# The next request will be sent after the interval.
await
asyncio
.
sleep
(
interval
)
delay_ts
.
append
(
0
)
else
:
theta
=
1.0
/
(
current_request_rate
*
burstiness
)
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
delay_ts
.
append
(
np
.
random
.
gamma
(
shape
=
burstiness
,
scale
=
theta
))
# Calculate the cumulative delay time from the first sent out requests.
for
i
in
range
(
1
,
len
(
delay_ts
)):
delay_ts
[
i
]
+=
delay_ts
[
i
-
1
]
if
ramp_up_strategy
is
None
and
delay_ts
[
-
1
]
!=
0
:
# When ramp_up_strategy is not set, we assume the request rate is fixed
# and all requests should be sent in target_total_delay_s, the following
# logic would re-scale delay time to ensure the final delay_ts
# align with target_total_delay_s.
#
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# from target_total_delay_s. The purpose of the following logic is to
# close the gap for stablizing the throughput data
# from different random seeds.
target_total_delay_s
=
total_requests
/
request_rate
normalize_factor
=
target_total_delay_s
/
delay_ts
[
-
1
]
delay_ts
=
[
delay
*
normalize_factor
for
delay
in
delay_ts
]
start_ts
=
time
.
time
()
for
request_index
,
request
in
enumerate
(
input_requests
):
if
delay_ts
[
request_index
]
>
0
:
current_ts
=
time
.
time
()
sleep_interval_s
=
start_ts
+
delay_ts
[
request_index
]
-
current_ts
if
sleep_interval_s
>
0
:
await
asyncio
.
sleep
(
sleep_interval_s
)
yield
request
,
request_rates
[
request_index
]
def
calculate_metrics
(
...
...
vllm/collect_env.py
View file @
711aa9d5
...
...
@@ -96,25 +96,30 @@ DEFAULT_PIP_PATTERNS = {
def
run
(
command
):
"""Return (return-code, stdout, stderr)."""
shell
=
True
if
type
(
command
)
is
str
else
False
p
=
subprocess
.
Popen
(
command
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
shell
=
shell
)
raw_output
,
raw_err
=
p
.
communicate
()
rc
=
p
.
returncode
if
get_platform
()
==
'win32'
:
enc
=
'oem'
else
:
enc
=
locale
.
getpreferredencoding
()
output
=
raw_output
.
decode
(
enc
)
if
command
==
'nvidia-smi topo -m'
:
# don't remove the leading whitespace of `nvidia-smi topo -m`
# because they are meaningful
output
=
output
.
rstrip
()
else
:
output
=
output
.
strip
()
err
=
raw_err
.
decode
(
enc
)
return
rc
,
output
,
err
.
strip
()
try
:
p
=
subprocess
.
Popen
(
command
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
shell
=
shell
)
raw_output
,
raw_err
=
p
.
communicate
()
rc
=
p
.
returncode
if
get_platform
()
==
'win32'
:
enc
=
'oem'
else
:
enc
=
locale
.
getpreferredencoding
()
output
=
raw_output
.
decode
(
enc
)
if
command
==
'nvidia-smi topo -m'
:
# don't remove the leading whitespace of `nvidia-smi topo -m`
# because they are meaningful
output
=
output
.
rstrip
()
else
:
output
=
output
.
strip
()
err
=
raw_err
.
decode
(
enc
)
return
rc
,
output
,
err
.
strip
()
except
FileNotFoundError
:
cmd_str
=
command
if
isinstance
(
command
,
str
)
else
command
[
0
]
return
127
,
''
,
f
"Command not found:
{
cmd_str
}
"
def
run_and_read_all
(
run_lambda
,
command
):
...
...
@@ -148,7 +153,7 @@ def get_conda_packages(run_lambda, patterns=None):
if
patterns
is
None
:
patterns
=
DEFAULT_CONDA_PATTERNS
conda
=
os
.
environ
.
get
(
'CONDA_EXE'
,
'conda'
)
out
=
run_and_read_all
(
run_lambda
,
"{} list"
.
format
(
conda
)
)
out
=
run_and_read_all
(
run_lambda
,
[
conda
,
'list'
]
)
if
out
is
None
:
return
out
...
...
vllm/compilation/backends.py
View file @
711aa9d5
...
...
@@ -120,10 +120,15 @@ class CompilerManager:
handle
=
self
.
cache
[(
runtime_shape
,
graph_index
,
self
.
compiler
.
name
)]
compiled_graph
=
self
.
compiler
.
load
(
handle
,
graph
,
example_inputs
,
graph_index
,
runtime_shape
)
logger
.
debug
(
"Directly load the %s-th graph for shape %s from %s via "
"handle %s"
,
graph_index
,
str
(
runtime_shape
),
self
.
compiler
.
name
,
handle
)
if
runtime_shape
is
None
:
logger
.
debug
(
"Directly load the %s-th graph for dynamic shape from %s via "
"handle %s"
,
graph_index
,
self
.
compiler
.
name
,
handle
)
else
:
logger
.
debug
(
"Directly load the %s-th graph for shape %s from %s via "
"handle %s"
,
graph_index
,
str
(
runtime_shape
),
self
.
compiler
.
name
,
handle
)
return
compiled_graph
def
compile
(
self
,
...
...
@@ -152,9 +157,15 @@ class CompilerManager:
# there can be multiple graphs due to piecewise compilation.
now
=
time
.
time
()
elapsed
=
now
-
compilation_start_time
logger
.
info
(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s"
,
str
(
runtime_shape
),
elapsed
)
if
runtime_shape
is
None
:
logger
.
info
(
"Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s"
,
elapsed
)
else
:
logger
.
info
(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s"
,
str
(
runtime_shape
),
elapsed
)
return
compiled_graph
# no compiler cached the graph, or the cache is disabled,
...
...
@@ -172,17 +183,28 @@ class CompilerManager:
assert
compiled_graph
is
not
None
,
"Failed to compile the graph"
# store the artifact in the cache
if
handle
is
not
None
:
if
not
envs
.
VLLM_DISABLE_COMPILE_CACHE
and
handle
is
not
None
:
self
.
cache
[(
runtime_shape
,
graph_index
,
self
.
compiler
.
name
)]
=
handle
compilation_counter
.
num_cache_entries_updated
+=
1
self
.
is_cache_updated
=
True
if
graph_index
==
0
:
# adds some info logging for the first graph
logger
.
info
(
"Cache the graph of shape %s for later use"
,
str
(
runtime_shape
))
logger
.
debug
(
"store the %s-th graph for shape %s from %s via handle %s"
,
graph_index
,
str
(
runtime_shape
),
self
.
compiler
.
name
,
handle
)
if
runtime_shape
is
None
:
logger
.
info
(
"Cache the graph for dynamic shape for later use"
)
else
:
logger
.
info
(
"Cache the graph of shape %s for later use"
,
str
(
runtime_shape
))
if
runtime_shape
is
None
:
logger
.
debug
(
"Store the %s-th graph for dynamic shape from %s via "
"handle %s"
,
graph_index
,
self
.
compiler
.
name
,
handle
)
else
:
logger
.
debug
(
"Store the %s-th graph for shape %s from %s via handle %s"
,
graph_index
,
str
(
runtime_shape
),
self
.
compiler
.
name
,
handle
)
# after compiling the last graph, record the end time
if
graph_index
==
num_graphs
-
1
:
...
...
@@ -190,7 +212,7 @@ class CompilerManager:
elapsed
=
now
-
compilation_start_time
compilation_config
.
compilation_time
+=
elapsed
if
runtime_shape
is
None
:
logger
.
info
(
"Compiling a graph for
general
shape takes %.2f s"
,
logger
.
info
(
"Compiling a graph for
dynamic
shape takes %.2f s"
,
elapsed
)
else
:
logger
.
info
(
"Compiling a graph for shape %s takes %.2f s"
,
...
...
@@ -308,7 +330,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
]
global
compilation_start_time
compiled_graph_for_
general
_shape
=
self
.
vllm_backend
.
\
compiled_graph_for_
dynamic
_shape
=
self
.
vllm_backend
.
\
compiler_manager
.
compile
(
submod
,
args
,
...
...
@@ -323,7 +345,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
module
.
__dict__
[
target
]
=
piecewise_backend
(
submod
,
self
.
vllm_config
,
self
.
graph_pool
,
index
,
len
(
self
.
compile_submod_names
),
sym_shape_indices
,
compiled_graph_for_
general
_shape
,
self
.
vllm_backend
)
compiled_graph_for_
dynamic
_shape
,
self
.
vllm_backend
)
compilation_counter
.
num_piecewise_capturable_graphs_seen
+=
1
...
...
vllm/compilation/collective_fusion.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
importlib.util
import
find_spec
from
typing
import
Optional
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch.fx
as
fx
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch.distributed._symmetric_memory
import
enable_symm_mem_for_group
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tp_group
from
vllm.distributed
import
get_tp_group
,
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
direct_register_custom_op
from
.vllm_inductor_pass
import
VllmInductorPass
if
find_spec
(
"flashinfer"
):
try
:
import
flashinfer.comm
as
flashinfer_comm
flashinfer_comm
=
(
flashinfer_comm
if
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
)
else
None
)
except
ImportError
:
flashinfer_comm
=
None
else
:
flashinfer_comm
=
None
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
ALLREDUCE_OP
=
torch
.
ops
.
vllm
.
all_reduce
.
default
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
class
BasePattern
:
...
...
@@ -43,7 +61,8 @@ class GEMMReduceScatterPattern(BasePattern):
mm
,
dim
=
0
,
world_size
=
self
.
tp_size
,
group_name
=
self
.
tp
.
unique_name
)
group_name
=
self
.
tp
.
unique_name
,
)
return
reduce_scatter
def
replacement
(
mul
:
torch
.
Tensor
,
mm_weight
:
torch
.
Tensor
):
...
...
@@ -79,7 +98,8 @@ class AllGatherGEMMPattern(BasePattern):
x
,
dim
=
0
,
world_size
=
self
.
tp_size
,
group_name
=
self
.
tp
.
unique_name
)
group_name
=
self
.
tp
.
unique_name
,
)
return
torch
.
ops
.
aten
.
mm
.
default
(
all_gather
,
weight
)
...
...
@@ -125,3 +145,343 @@ class AsyncTPPass(VllmInductorPass):
logger
.
debug
(
"Replaced %s patterns"
,
count
)
self
.
dump_graph
(
graph
,
"after_async_tp_pass"
)
self
.
end_and_log
()
if
flashinfer_comm
is
not
None
:
_FI_WORKSPACE_TENSOR
=
None
MiB
=
1024
*
1024
# Max size of the input tensor per world size
# to use flashinfer fused allreduce
_FI_MAX_SIZES
=
{
2
:
MiB
,
# 1MB
4
:
MiB
,
# 1MB
6
:
MiB
//
2
,
# 512KB
8
:
MiB
//
2
,
# 512KB
}
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE
=
MiB
//
2
def
call_trtllm_fused_allreduce_norm
(
allreduce_in
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
world_rank
:
int
,
world_size
:
int
,
launch_with_pdl
:
bool
,
trigger_completion_at_end
:
bool
,
fp32_acc
:
bool
,
max_token_num
:
int
,
norm_out
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
num_tokens
,
hidden_size
=
allreduce_in
.
shape
element_size
=
allreduce_in
.
element_size
()
current_tensor_size
=
num_tokens
*
hidden_size
*
element_size
max_fusion_size
=
max_token_num
*
hidden_size
*
element_size
use_flashinfer
=
current_tensor_size
<=
min
(
_FI_MAX_SIZES
.
get
(
world_size
,
_DEFAULT_FI_MAX_SIZE
),
max_fusion_size
,
)
if
use_flashinfer
:
assert
(
_FI_WORKSPACE_TENSOR
is
not
None
),
"Flashinfer must be enabled when using flashinfer"
if
norm_out
is
None
:
norm_out
=
allreduce_in
residual_out
=
residual
else
:
# return residual_out as allreduce_out with zeroed residual_in
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out
=
allreduce_in
# For the sizes that are smaller than the max size,
# we only use flashinfer one shot allreduce
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
allreduce_in
,
token_num
=
allreduce_in
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
world_rank
=
world_rank
,
world_size
=
world_size
,
hidden_dim
=
allreduce_in
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
launch_with_pdl
=
launch_with_pdl
,
use_oneshot
=
True
,
trigger_completion_at_end
=
trigger_completion_at_end
,
fp32_acc
=
fp32_acc
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNorm
,
allreduce_out
=
None
,
quant_out
=
None
,
scale_out
=
None
,
layout_code
=
None
,
scale_factor
=
None
,
)
else
:
allreduce_out
=
tensor_model_parallel_all_reduce
(
allreduce_in
)
if
norm_out
is
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm
(
allreduce_out
,
residual
,
rms_gamma
,
rms_eps
)
else
:
torch
.
ops
.
_C
.
rms_norm
(
norm_out
,
allreduce_out
,
rms_gamma
,
rms_eps
)
allreduce_in
.
copy_
(
allreduce_out
)
def
call_trtllm_fused_allreduce_norm_fake
(
allreduce_in
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
world_rank
:
int
,
world_size
:
int
,
launch_with_pdl
:
bool
,
trigger_completion_at_end
:
bool
,
fp32_acc
:
bool
,
max_token_num
:
int
,
norm_out
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
pass
direct_register_custom_op
(
op_name
=
"flashinfer_trtllm_fused_allreduce_norm"
,
op_func
=
call_trtllm_fused_allreduce_norm
,
mutates_args
=
[
"allreduce_in"
,
"residual"
,
"norm_out"
,
],
fake_impl
=
call_trtllm_fused_allreduce_norm_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
flashinfer_trtllm_fused_allreduce_norm
=
(
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
)
class
FlashInferFusedAllReduceParams
:
"""Parameters for FlashInfer fused allreduce operations."""
def
__init__
(
self
,
rank
:
int
,
world_size
:
int
,
use_fp32_lamport
:
bool
=
False
,
max_token_num
:
int
=
1024
,
):
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
use_fp32_lamport
=
use_fp32_lamport
self
.
trigger_completion_at_end
=
True
self
.
launch_with_pdl
=
True
self
.
fp32_acc
=
True
self
.
use_oneshot
=
False
self
.
max_token_num
=
max_token_num
def
get_trtllm_fused_allreduce_kwargs
(
self
):
return
{
"world_rank"
:
self
.
rank
,
"world_size"
:
self
.
world_size
,
"launch_with_pdl"
:
self
.
launch_with_pdl
,
"trigger_completion_at_end"
:
self
.
trigger_completion_at_end
,
"fp32_acc"
:
self
.
fp32_acc
,
"max_token_num"
:
self
.
max_token_num
,
}
class
AllReduceRMSNORMPattern
(
BasePattern
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
def
get_inputs
(
self
):
input
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
rms_result
=
torch
.
empty
([
1
,
8
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
input
,
rms_result
,
weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
rms_result
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
all_reduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
auto_functionalized
(
RMS_OP
,
result
=
rms_result
,
input
=
all_reduce_output
,
weight
=
weight
,
epsilon
=
self
.
epsilon
,
)
return
rms
[
1
],
all_reduce_output
def
replacement
(
input
:
torch
.
Tensor
,
rms_result
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
residual
=
torch
.
zeros_like
(
input
)
allreduce
=
auto_functionalized
(
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
,
allreduce_in
=
input
,
residual
=
residual
,
norm_out
=
rms_result
,
rms_gamma
=
weight
,
rms_eps
=
self
.
epsilon
,
**
self
.
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
return
allreduce
[
3
],
allreduce
[
1
]
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
AllReduceFusedAddRMSNormPattern
(
BasePattern
):
def
__init__
(
self
,
epsilon
:
float
,
dtype
:
torch
.
dtype
,
device
:
str
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
):
super
().
__init__
(
dtype
,
device
)
self
.
epsilon
=
epsilon
self
.
allreduce_params
=
allreduce_params
def
get_inputs
(
self
):
input
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
residual
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
weight
=
torch
.
empty
([
4
,
4
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
residual
,
input
,
weight
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
all_reduce_output
=
tensor_model_parallel_all_reduce
(
input
)
rms
=
auto_functionalized
(
RMS_ADD_OP
,
input
=
all_reduce_output
,
residual
=
residual
,
weight
=
weight
,
epsilon
=
self
.
epsilon
,
)
return
rms
[
1
],
rms
[
2
]
def
replacement
(
residual
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
allreduce
=
auto_functionalized
(
torch
.
ops
.
vllm
.
flashinfer_trtllm_fused_allreduce_norm
.
default
,
allreduce_in
=
input
,
residual
=
residual
,
rms_gamma
=
weight
,
rms_eps
=
self
.
epsilon
,
norm_out
=
None
,
**
self
.
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
return
allreduce
[
1
],
allreduce
[
2
]
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
AllReduceFusionPass
(
VllmInductorPass
):
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
self
.
disabled
=
True
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
tp_size
<=
1
:
return
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"all_reduce_fusion_pass"
)
if
config
.
model_config
is
None
:
return
self
.
hidden_dim
=
config
.
model_config
.
get_hidden_size
()
self
.
group
=
get_tp_group
().
device_group
rank
=
get_tensor_model_parallel_rank
()
use_fp32_lamport
=
self
.
model_dtype
==
torch
.
float32
if
flashinfer_comm
is
None
:
logger
.
warning
(
"Flashinfer is not installed or comm module not found, "
"skipping allreduce fusion pass"
)
return
# Check if the world size is supported
if
self
.
tp_size
not
in
_FI_MAX_SIZES
:
logger
.
warning
(
"Flashinfer allreduce fusion is not "
"supported for world size %s"
,
self
.
tp_size
,
)
return
self
.
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
tp_rank
=
rank
,
tp_size
=
self
.
tp_size
,
max_token_num
=
config
.
compilation_config
.
pass_config
.
fi_allreduce_fusion_max_token_num
,
hidden_dim
=
self
.
hidden_dim
,
group
=
self
.
group
,
use_fp32_lamport
=
use_fp32_lamport
,
))
global
_FI_WORKSPACE_TENSOR
_FI_WORKSPACE_TENSOR
=
workspace_tensor
self
.
allreduce_params
=
FlashInferFusedAllReduceParams
(
rank
=
rank
,
world_size
=
self
.
tp_size
,
use_fp32_lamport
=
use_fp32_lamport
,
max_token_num
=
config
.
compilation_config
.
pass_config
.
fi_allreduce_fusion_max_token_num
,
)
for
epsilon
in
[
1e-5
,
1e-6
]:
AllReduceRMSNORMPattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
self
.
allreduce_params
,
).
register
(
self
.
patterns
)
AllReduceFusedAddRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
self
.
allreduce_params
,
).
register
(
self
.
patterns
)
self
.
disabled
=
False
def
__call__
(
self
,
graph
:
fx
.
Graph
):
if
self
.
disabled
:
return
self
.
begin
()
self
.
dump_graph
(
graph
,
"before_all_reduce_fusion_pass"
)
count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
count
)
self
.
dump_graph
(
graph
,
"after_all_reduce_fusion_pass"
)
self
.
end_and_log
()
def
__del__
(
self
):
if
self
.
disabled
:
return
if
flashinfer_comm
is
not
None
:
flashinfer_comm
.
trtllm_destroy_ipc_workspace
(
self
.
ipc_handles
,
self
.
group
)
vllm/compilation/compiler_interface.py
View file @
711aa9d5
...
...
@@ -213,7 +213,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
# Save the compiled artifact to disk in the specified path
assert
key
is
not
None
path
=
os
.
path
.
join
(
self
.
cache_dir
,
key
)
compiled_graph
.
save
(
path
=
path
,
format
=
"unpacked"
)
if
not
envs
.
VLLM_DISABLE_COMPILE_CACHE
:
compiled_graph
.
save
(
path
=
path
,
format
=
"unpacked"
)
compilation_counter
.
num_compiled_artifacts_saved
+=
1
return
compiled_graph
,
(
key
,
path
)
def
load
(
self
,
...
...
@@ -421,6 +423,12 @@ class InductorAdaptor(CompilerInterface):
if
is_torch_equal_or_newer
(
"2.6"
):
stack
.
enter_context
(
torch
.
_inductor
.
config
.
patch
(
fx_graph_remote_cache
=
False
))
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack
.
enter_context
(
torch
.
_functorch
.
config
.
patch
(
enable_autograd_cache
=
False
))
stack
.
enter_context
(
torch
.
_functorch
.
config
.
patch
(
enable_remote_autograd_cache
=
False
))
...
...
Prev
1
…
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment