Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eebad39f
Unverified
Commit
eebad39f
authored
Nov 22, 2024
by
youkaichao
Committed by
GitHub
Nov 22, 2024
Browse files
[torch.compile] support all attention backends (#10558)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
db100c5c
Changes
77
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
519 additions
and
490 deletions
+519
-490
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+26
-11
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+14
-9
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+1
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+172
-240
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+111
-169
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+1
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+1
-1
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+1
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-6
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+2
-2
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+4
-4
vllm/attention/layer.py
vllm/attention/layer.py
+72
-9
vllm/config.py
vllm/config.py
+7
-2
vllm/forward_context.py
vllm/forward_context.py
+22
-5
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+11
-4
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+13
-5
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+35
-13
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+10
-4
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+9
-2
No files found.
tests/kernels/test_encoder_decoder_attn.py
View file @
eebad39f
...
...
@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
set_current_vllm_config
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
...
...
@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
encoder_test_params
:
PhaseTestParameters
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
vllm_config
:
VllmConfig
,
)
->
torch
.
Tensor
:
'''
Run encoder attention.
...
...
@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
attn_type
=
AttentionType
.
ENCODER
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
):
with
set_forward_context
(
attn_metadata
,
vllm_config
):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
...
...
@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
decoder_test_params
:
PhaseTestParameters
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
vllm_config
:
VllmConfig
,
)
->
torch
.
Tensor
:
'''
Run decoder self-attention test.
...
...
@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
):
with
set_forward_context
(
attn_metadata
,
vllm_config
):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
...
...
@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_test_params
:
Optional
[
PhaseTestParameters
],
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
vllm_config
:
VllmConfig
,
)
->
torch
.
Tensor
:
'''
Run encoder/decoder cross-attention test.
...
...
@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv
=
cross_test_params
.
packed_qkvo
.
packed_qkv
key
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
key
)
value
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
value
)
with
set_forward_context
(
attn_metadata
):
with
set_forward_context
(
attn_metadata
,
vllm_config
):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
...
...
@@ -839,7 +844,9 @@ def test_encoder_only(
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs
=
_make_test_resources
(
test_pt
)
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
test_rsrcs
=
_make_test_resources
(
test_pt
)
# Construct encoder attention test params (only used
# during prefill)
...
...
@@ -863,7 +870,8 @@ def test_encoder_only(
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
))
test_pt
=
test_pt
,
vllm_config
=
vllm_config
))
# - Is encoder attention result correct?
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
,
...
...
@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
test_rsrcs
=
_make_test_resources
(
test_pt
)
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
test_rsrcs
=
_make_test_resources
(
test_pt
)
# Construct encoder attention test params (only used
# during prefill)
...
...
@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out
=
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
test_pt
=
test_pt
,
vllm_config
=
vllm_config
)
# - Is encoder attention result correct?
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
,
...
...
@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs
,
prephase_dec_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
test_pt
=
test_pt
,
vllm_config
=
vllm_config
)
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal
(
prephase_dec_test_params
,
...
...
@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params
,
prephase_cross_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
test_pt
=
test_pt
,
vllm_config
=
vllm_config
)
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal
(
prephase_cross_test_params
,
...
...
@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs
,
decphase_dec_test_params
,
decphase_attn_metadata
,
test_pt
=
test_pt
)
test_pt
=
test_pt
,
vllm_config
=
vllm_config
)
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal
(
decphase_dec_test_params
,
...
...
@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
decphase_dec_test_params
,
None
,
decphase_attn_metadata
,
test_pt
=
test_pt
)
test_pt
=
test_pt
,
vllm_config
=
vllm_config
)
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal
(
decphase_cross_test_params
,
...
...
vllm/attention/backends/abstract.py
View file @
eebad39f
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
fields
from
enum
import
Enum
,
auto
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
...
...
@@ -15,13 +14,19 @@ if TYPE_CHECKING:
ModelRunnerInputBuilderBase
)
class
AttentionType
(
Enum
):
DECODER
=
auto
()
# Decoder attention between previous layer Q/K/V
ENCODER
=
auto
(
)
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER_ONLY
=
auto
()
# Encoder attention between previous layer Q/K/V
ENCODER_DECODER
=
auto
(
)
# Attention between dec. Q and enc. K/V for encoder-decoder
class
AttentionType
:
"""
Attention type.
Use string to be compatible with `torch.compile`.
"""
# Decoder attention between previous layer Q/K/V
DECODER
=
"decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER
=
"encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY
=
"encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER
=
"encoder_decoder"
class
AttentionBackend
(
ABC
):
...
...
@@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata
:
T
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
eebad39f
...
...
@@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
vllm/attention/backends/flash_attn.py
View file @
eebad39f
...
...
@@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.forward_context
import
get_forward_context
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
make_tensor_with_pad
)
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata
:
FlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
...
...
@@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl):
"requires setting cross-attention "
"metadata attributes."
)
output
=
torch
.
ops
.
vllm
.
unified_flash_attention
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
head_size
,
self
.
num_kv_heads
,
kv_cache
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
self
.
scale
,
attn_type
.
value
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
num_heads
:
int
=
self
.
num_heads
head_size
:
int
=
self
.
head_size
num_kv_heads
:
int
=
self
.
num_kv_heads
kv_cache_dtype
:
str
=
self
.
kv_cache_dtype
softmax_scale
:
float
=
self
.
scale
window_size
=
self
.
sliding_window
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
self
.
alibi_slopes
logits_soft_cap
:
Optional
[
float
]
=
self
.
logits_soft_cap
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the
# KV cache is already populated with the cross-attention
# tensor. Thus, we skip cache updates during this time.
if
(
attn_type
!=
AttentionType
.
ENCODER
)
and
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
updated_slot_mapping
.
flatten
(),
# type: ignore[union-attr]
kv_cache_dtype
,
k_scale
,
v_scale
,
)
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
decode_query
=
query
[
num_prefill_query_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_query_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_query_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_query_tokens
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc
,
q_seq_len
,
k_seq_start_loc
,
k_seq_len
=
\
_get_query_key_seq_metadata
(
prefill_meta
,
True
,
attn_type
)
key
=
key
[:
num_prefill_kv_tokens
]
value
=
value
[:
num_prefill_kv_tokens
]
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
else
:
# prefix-enabled attention
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support prefix caching"
)
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
prefill_output
=
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert
decode_meta
.
max_decode_query_len
is
not
None
# use only for actual varlen decoding
if
decode_meta
.
max_decode_query_len
>
1
:
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output
=
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
cu_seqlens_k
=
decode_meta
.
seq_start_loc
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg
,
_
,
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
).
squeeze
(
1
)
if
prefill_output
is
None
:
assert
decode_output
is
not
None
return
decode_output
.
view
(
num_decode_query_tokens
,
hidden_size
)
if
decode_output
is
None
:
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_query_tokens
,
hidden_size
)
assert
decode_meta
is
not
None
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
...
...
@@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl):
def
_get_query_key_seq_metadata
(
attn_metadata
,
is_prompt
:
bool
,
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
tuple
:
"""
Returns sequence metadata for key and query based on the specified
...
...
@@ -754,7 +903,7 @@ def _get_query_key_seq_metadata(
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
_get_causal_option
(
attn_type
:
AttentionType
)
->
bool
:
def
_get_causal_option
(
attn_type
:
str
)
->
bool
:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.
...
...
@@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool:
return
not
(
attn_type
==
AttentionType
.
ENCODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
or
attn_type
==
AttentionType
.
ENCODER_DECODER
)
def
unified_flash_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
attn_type_int_val
:
int
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
# Convert integer attn_type to enum
try
:
attn_type
=
AttentionType
(
attn_type_int_val
)
except
ValueError
as
err
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type_int_val
)
}
"
)
from
err
current_metadata
=
get_forward_context
()
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashAttentionMetadata
)
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if
(
attn_type
!=
AttentionType
.
ENCODER
)
and
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
updated_slot_mapping
.
flatten
(),
# type: ignore[union-attr]
kv_cache_dtype
,
k_scale
,
v_scale
,
)
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
decode_query
=
query
[
num_prefill_query_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_query_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_query_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_query_tokens
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc
,
q_seq_len
,
k_seq_start_loc
,
k_seq_len
=
\
_get_query_key_seq_metadata
(
prefill_meta
,
True
,
attn_type
)
key
=
key
[:
num_prefill_kv_tokens
]
value
=
value
[:
num_prefill_kv_tokens
]
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
else
:
# prefix-enabled attention
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support prefix caching"
)
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
prefill_output
=
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert
decode_meta
.
max_decode_query_len
is
not
None
# use only for actual varlen decoding
if
decode_meta
.
max_decode_query_len
>
1
:
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output
=
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
cu_seqlens_k
=
decode_meta
.
seq_start_loc
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg
,
_
,
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
).
squeeze
(
1
)
if
prefill_output
is
None
:
assert
decode_output
is
not
None
return
decode_output
.
view
(
num_decode_query_tokens
,
hidden_size
)
if
decode_output
is
None
:
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_query_tokens
,
hidden_size
)
assert
decode_meta
is
not
None
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
def
unified_flash_attention_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
attn_type_int_val
:
int
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
)
direct_register_custom_op
(
op_name
=
"unified_flash_attention"
,
op_func
=
unified_flash_attention
,
mutates_args
=
[
"kv_cache"
],
fake_impl
=
unified_flash_attention_fake
,
)
vllm/attention/backends/flashinfer.py
View file @
eebad39f
...
...
@@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata
:
FlashInferMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
@@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for "
"FlashInferImpl"
)
return
torch
.
ops
.
vllm
.
unified_flash_infer
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
head_size
,
self
.
num_kv_heads
,
kv_cache
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
self
.
scale
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
def
unified_flash_infer
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
current_metadata
=
get_forward_context
()
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashInferMetadata
)
attn_metadata
:
FlashInferMetadata
=
current_metadata
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[:,
0
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if
kv_cache_dtype
.
startswith
(
"fp8"
):
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
)
kv_cache
=
kv_cache
.
view
(
torch_dtype
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"key :
{
key
.
shape
}
: #prefill tokens
{
num_prefill_tokens
}
: #decode tokens
{
num_decode_tokens
}
"
# noqa
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"value :
{
value
.
shape
}
: #prefill toks
{
num_prefill_tokens
}
: #decode toks
{
num_decode_tokens
}
"
# noqa
query
=
query
.
contiguous
()
# Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query
=
query
[
num_prefill_tokens
:]
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
.
numel
()
==
0
:
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
num_heads
:
int
=
self
.
num_heads
head_size
:
int
=
self
.
head_size
num_kv_heads
:
int
=
self
.
num_kv_heads
kv_cache_dtype
:
str
=
self
.
kv_cache_dtype
softmax_scale
:
float
=
self
.
scale
window_size
=
self
.
sliding_window
alibi_slopes
=
self
.
alibi_slopes
logits_soft_cap
=
self
.
logits_soft_cap
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[:,
0
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
else
:
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
prefill_output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if
kv_cache_dtype
.
startswith
(
"fp8"
):
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
)
kv_cache
=
kv_cache
.
view
(
torch_dtype
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"key :
{
key
.
shape
}
: #prefill tokens
{
num_prefill_tokens
}
: #decode tokens
{
num_decode_tokens
}
"
# noqa
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"value :
{
value
.
shape
}
: #prefill toks
{
num_prefill_tokens
}
: #decode toks
{
num_decode_tokens
}
"
# noqa
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query
=
query
[
num_prefill_tokens
:]
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
.
numel
()
==
0
:
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
)
else
:
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
prefill_output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
kv_cache
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
window_left
=
window_left
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
decode_meta
is
not
None
assert
decode_meta
.
decode_wrapper
is
not
None
decode_output
=
decode_meta
.
decode_wrapper
.
forward
(
decode_query
,
kv_cache
,
sm_scale
=
softmax_scale
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
window_left
=
window_left
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
decode_output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
decode_query
,
kv_cache
,
sm_scale
=
softmax_scale
,
logits_soft_cap
=
logits_soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
window_left
=
window_left
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
# Decode only batch.
output
,
num_tokens
=
decode_output
,
num_decode_tokens
elif
decode_output
is
None
and
prefill_output
is
not
None
:
# Prefill only batch.
output
,
num_tokens
=
prefill_output
,
num_prefill_tokens
else
:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert
prefill_output
is
not
None
assert
decode_output
is
not
None
assert
decode_meta
is
not
None
assert
decode_meta
.
decode_query_len
==
1
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
def
unified_flash_infer_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_flash_infer"
,
op_func
=
unified_flash_infer
,
mutates_args
=
[
"kv_cache"
],
fake_impl
=
unified_flash_infer_fake
,
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
# Decode only batch.
output
,
num_tokens
=
decode_output
,
num_decode_tokens
elif
decode_output
is
None
and
prefill_output
is
not
None
:
# Prefill only batch.
output
,
num_tokens
=
prefill_output
,
num_prefill_tokens
else
:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert
prefill_output
is
not
None
assert
decode_output
is
not
None
assert
decode_meta
is
not
None
assert
decode_meta
.
decode_query_len
==
1
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/hpu_attn.py
View file @
eebad39f
...
...
@@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata
:
HPUAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
vllm/attention/backends/ipex_attn.py
View file @
eebad39f
...
...
@@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
...
vllm/attention/backends/pallas.py
View file @
eebad39f
...
...
@@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata
:
PallasMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
eebad39f
...
...
@@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata
:
ROCmFlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
vllm/attention/backends/torch_sdpa.py
View file @
eebad39f
...
...
@@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def
get_seq_lens
(
self
,
attn_type
:
AttentionType
,
attn_type
:
str
,
):
'''
Extract appropriate sequence lengths from attention metadata
...
...
@@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def
get_attn_bias
(
self
,
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
Optional
[
List
[
torch
.
Tensor
]]:
'''
Extract appropriate attention bias from attention metadata
...
...
@@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def
set_attn_bias
(
self
,
attn_bias
:
List
[
torch
.
Tensor
],
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
...
...
@@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def
get_seq_len_block_table_args
(
self
,
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
...
...
@@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
...
...
vllm/attention/backends/utils.py
View file @
eebad39f
...
...
@@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
def
get_seq_len_block_table_args
(
attn_metadata
,
is_prompt
:
bool
,
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
...
...
@@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
def
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
Tuple
[
int
,
int
,
int
]:
"""
Calculate the number of prefill and decode tokens for query, key/value
...
...
vllm/attention/backends/xformers.py
View file @
eebad39f
...
...
@@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
def
_get_attn_bias
(
attn_metadata
:
XFormersMetadata
,
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
Optional
[
AttentionBias
]:
'''
Extract appropriate attention bias from attention metadata
...
...
@@ -314,7 +314,7 @@ def _get_attn_bias(
def
_set_attn_bias
(
attn_metadata
:
XFormersMetadata
,
attn_bias
:
List
[
Optional
[
AttentionBias
]],
attn_type
:
AttentionType
,
attn_type
:
str
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
...
...
@@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata
:
"XFormersMetadata"
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
@@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
XFormersMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
...
...
vllm/attention/layer.py
View file @
eebad39f
...
...
@@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
get_current_vllm_config
from
vllm.utils
import
direct_register_custom_op
class
Attention
(
nn
.
Module
):
...
...
@@ -86,6 +91,18 @@ class Attention(nn.Module):
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self
.
use_direct_call
=
envs
.
VLLM_USE_V1
or
not
(
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_cpu
())
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
def
forward
(
self
,
query
:
torch
.
Tensor
,
...
...
@@ -93,17 +110,22 @@ class Attention(nn.Module):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
if
self
.
use_direct_call
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
kv_cache
,
attn_type
,
self
.
layer_name
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
...
@@ -112,3 +134,44 @@ class Attention(nn.Module):
s
+=
f
", scale=
{
self
.
impl
.
scale
}
"
# type: ignore
s
+=
f
", backend=
{
self
.
impl
.
__class__
.
__name__
}
"
return
s
def
unified_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
self
=
forward_context
.
static_forward_context
[
layer_name
]
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
def
unified_attention_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_attention"
,
op_func
=
unified_attention
,
mutates_args
=
[
"kv_cache"
],
fake_impl
=
unified_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/config.py
View file @
eebad39f
...
...
@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
backend
:
str
=
""
custom_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
splitting_ops
:
List
[
str
]
=
Field
(
default_factory
=
lambda
:
[
"vllm.unified_flash_attention"
,
"vllm.unified_flash_infer"
,
"vllm.unified_attention"
,
"vllm.unified_v1_flash_attention"
,
])
...
...
@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context
:
Dict
[
str
,
Any
]
=
PrivateAttr
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
)
->
"CompilationConfig"
:
"""Parse the CLI value for the compilation config."""
...
...
@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
self
.
enabled_custom_ops
=
Counter
()
self
.
disabled_custom_ops
=
Counter
()
self
.
static_forward_context
=
{}
def
init_backend
(
self
)
->
Union
[
str
,
Callable
]:
if
self
.
level
==
CompilationLevel
.
NO_COMPILATION
:
...
...
vllm/forward_context.py
View file @
eebad39f
from
contextlib
import
contextmanager
from
typing
import
Any
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Optional
_forward_context
:
Any
=
None
from
vllm.config
import
VllmConfig
def
get_forward_context
()
->
Any
:
@
dataclass
class
ForwardContext
:
static_forward_context
:
Dict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context
:
Any
_forward_context
:
Optional
[
ForwardContext
]
=
None
def
get_forward_context
()
->
ForwardContext
:
"""Get the current forward context."""
assert
_forward_context
is
not
None
,
(
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context."
)
return
_forward_context
@
contextmanager
def
set_forward_context
(
context
:
Any
):
def
set_forward_context
(
context
:
Any
,
vllm_config
:
VllmConfig
):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
global
_forward_context
prev_context
=
_forward_context
_forward_context
=
context
_forward_context
=
ForwardContext
(
static_forward_context
=
vllm_config
.
compilation_config
.
static_forward_context
,
dynamic_forward_context
=
context
)
try
:
yield
finally
:
...
...
vllm/model_executor/models/arctic.py
View file @
eebad39f
...
...
@@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
layer_idx
:
Optional
[
int
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
...
...
@@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
layer_idx
=
layer_idx
...
...
@@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
self
.
self_attn
=
ArcticAttention
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
block_sparse_moe
=
ArcticMoE
(
config
,
layer_id
=
layer_idx
,
...
...
@@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
org_num_embeddings
=
self
.
vocab_size
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
ArcticDecoderLayer
(
config
,
int
(
prefix
.
split
(
"."
)[
-
1
]),
cache_config
,
quant_config
),
lambda
prefix
:
ArcticDecoderLayer
(
config
,
int
(
prefix
.
split
(
"."
)[
-
1
]),
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
_attn_implementation
=
config
.
_attn_implementation
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
vllm/model_executor/models/baichuan.py
View file @
eebad39f
...
...
@@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
else
:
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
self
.
head_dim
,
self
.
scaling
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
...
...
@@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
config
:
PretrainedConfig
,
position_embedding
:
str
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
...
...
@@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
BaiChuanMLP
(
hidden_size
=
self
.
hidden_size
,
...
...
@@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
BaiChuanDecoderLayer
(
config
,
position_embedding
,
cache_config
,
quant_config
),
lambda
prefix
:
BaiChuanDecoderLayer
(
config
,
position_embedding
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
vllm/model_executor/models/bart.py
View file @
eebad39f
...
...
@@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
config
:
Optional
[
BartConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
...
...
@@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
...
...
@@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
config
:
Optional
[
BartConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
...
...
@@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
...
...
@@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
config
:
Optional
[
BartConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
...
...
@@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
...
...
@@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
config
:
BartConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
d_model
...
...
@@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
num_heads
=
config
.
encoder_attention_heads
,
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
self_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
)
...
...
@@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
config
:
BartConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
d_model
...
...
@@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
num_heads
=
config
.
decoder_attention_heads
,
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
)
self
.
self_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
...
...
@@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
self
.
embed_dim
,
config
.
decoder_attention_heads
,
config
=
config
,
prefix
=
f
"
{
prefix
}
.encoder_attn"
,
)
self
.
encoder_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
...
...
@@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
embed_tokens
:
Optional
[
nn
.
Embedding
]
=
None
):
embed_tokens
:
Optional
[
nn
.
Embedding
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
cache_config
=
cache_config
...
...
@@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
config
.
max_position_embeddings
,
embed_dim
,
)
self
.
layers
=
nn
.
ModuleList
(
[
BartEncoderLayer
(
config
,
cache_config
,
quant_config
)
\
for
_
in
range
(
config
.
encoder_layers
)])
self
.
layers
=
nn
.
ModuleList
([
BartEncoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
layer_idx
}
"
)
for
layer_idx
in
range
(
config
.
encoder_layers
)
])
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
embed_dim
)
...
...
@@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
embed_tokens
:
Optional
[
nn
.
Embedding
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
cache_config
=
cache_config
...
...
@@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
)
self
.
layers
=
nn
.
ModuleList
(
[
BartDecoderLayer
(
config
,
cache_config
,
quant_config
)
\
for
_
in
range
(
config
.
decoder_layers
)])
[
BartDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
layer_idx
}
"
)
\
for
layer_idx
in
range
(
config
.
decoder_layers
)])
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
config
.
d_model
)
...
...
@@ -759,10 +779,12 @@ class BartModel(nn.Module):
self
.
encoder
=
BartEncoder
(
config
,
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.encoder"
)
self
.
decoder
=
BartDecoder
(
config
,
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.decoder"
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/bloom.py
View file @
eebad39f
...
...
@@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
scaling
,
alibi_slopes
=
alibi_slopes
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
...
...
@@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
config
:
BloomConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
input_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
self_attention
=
BloomAttention
(
config
,
cache_config
,
quant_config
)
self
.
self_attention
=
BloomAttention
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attention"
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
BloomMLP
(
config
,
quant_config
)
...
...
@@ -242,7 +247,8 @@ class BloomModel(nn.Module):
# Transformer blocks
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
BloomBlock
(
config
,
cache_config
,
quant_config
),
lambda
prefix
:
BloomBlock
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.h"
)
# Final Layer Norm
...
...
vllm/model_executor/models/chameleon.py
View file @
eebad39f
...
...
@@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
config
:
ChameleonConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
quant_config
=
quant_config
,
bias
=
False
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
ChameleonMLP
(
hidden_size
=
self
.
hidden_size
,
...
...
@@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
config
:
ChameleonConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
quant_config
=
quant_config
,
bias
=
False
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
ChameleonMLP
(
hidden_size
=
self
.
hidden_size
,
...
...
@@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
config
.
num_hidden_layers
,
lambda
prefix
:
decoder_layer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment