"vllm/vscode:/vscode.git/clone" did not exist on "e97f802b2d74861af77997691a7d1c36498f6dca"
Unverified commit eebad39f, authored by youkaichao and committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent db100c5c
...@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, ...@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
...@@ -594,6 +596,7 @@ def _run_encoder_attention_test( ...@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
encoder_test_params: PhaseTestParameters, encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder attention. Run encoder attention.
...@@ -623,7 +626,7 @@ def _run_encoder_attention_test( ...@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
...@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test( ...@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
decoder_test_params: PhaseTestParameters, decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run decoder self-attention test. Run decoder self-attention test.
...@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test( ...@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
...@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_test_params: Optional[PhaseTestParameters], cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder/decoder cross-attention test. Run encoder/decoder cross-attention test.
...@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
...@@ -839,6 +844,8 @@ def test_encoder_only( ...@@ -839,6 +844,8 @@ def test_encoder_only(
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt) test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
...@@ -863,7 +870,8 @@ def test_encoder_only( ...@@ -863,7 +870,8 @@ def test_encoder_only(
test_rsrcs.attn, test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt)) test_pt=test_pt,
vllm_config=vllm_config))
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
...@@ -960,6 +968,8 @@ def test_e2e_enc_dec_attn( ...@@ -960,6 +968,8 @@ def test_e2e_enc_dec_attn(
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt) test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
...@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn( ...@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
...@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn( ...@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs, test_rsrcs,
prephase_dec_test_params, prephase_dec_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill decoder self-attention correct? # - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params, assert_actual_matches_ideal(prephase_dec_test_params,
...@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn( ...@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params, prephase_dec_test_params,
prephase_cross_test_params, prephase_cross_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct? # - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params, assert_actual_matches_ideal(prephase_cross_test_params,
...@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn( ...@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs, test_rsrcs,
decphase_dec_test_params, decphase_dec_test_params,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct? # - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params, assert_actual_matches_ideal(decphase_dec_test_params,
...@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn( ...@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
decphase_dec_test_params, decphase_dec_test_params,
None, None,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct? # - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params, assert_actual_matches_ideal(decphase_cross_test_params,
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar) Tuple, Type, TypeVar)
...@@ -15,13 +14,19 @@ if TYPE_CHECKING: ...@@ -15,13 +14,19 @@ if TYPE_CHECKING:
ModelRunnerInputBuilderBase) ModelRunnerInputBuilderBase)
class AttentionType(Enum): class AttentionType:
DECODER = auto() # Decoder attention between previous layer Q/K/V """
ENCODER = auto( Attention type.
) # Encoder attention between previous layer Q/K/V for encoder-decoder Use string to be compatible with `torch.compile`.
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V """
ENCODER_DECODER = auto( # Decoder attention between previous layer Q/K/V
) # Attention between dec. Q and enc. K/V for encoder-decoder DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER = "encoder_decoder"
class AttentionBackend(ABC): class AttentionBackend(ABC):
...@@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T, attn_metadata: T,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
......
...@@ -16,10 +16,8 @@ from vllm.attention.backends.utils import ( ...@@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, from vllm.utils import async_tensor_h2d, make_tensor_with_pad
make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -668,139 +666,14 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -668,139 +666,14 @@ class FlashAttentionImpl(AttentionImpl):
"requires setting cross-attention " "requires setting cross-attention "
"metadata attributes.") "metadata attributes.")
output = torch.ops.vllm.unified_flash_attention( num_heads: int = self.num_heads
query, head_size: int = self.head_size
key, num_kv_heads: int = self.num_kv_heads
value, kv_cache_dtype: str = self.kv_cache_dtype
self.num_heads, softmax_scale: float = self.scale
self.head_size, window_size = self.sliding_window
self.num_kv_heads, alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
kv_cache, logits_soft_cap: Optional[float] = self.logits_soft_cap
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
attn_type.value,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
return output
def _get_query_key_seq_metadata(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    """
    Select the query-side and key-side sequence metadata for an attention
    call.

    Every returned value is read straight from ``attn_metadata``; which
    fields are used depends on ``attn_type`` and, for decoder-side queries,
    on whether the input is a prompt.

    Args:
        attn_metadata: The attention metadata object.
        is_prompt (bool): A flag indicating if the input is a prompt.
        attn_type (AttentionType): The type of attention being used.

    Returns:
        tuple: A tuple containing four entries, in this order:
            - Starting location for the query sequence.
            - Maximum sequence length for the query sequence.
            - Starting location for the key sequence.
            - Maximum sequence length for the key sequence.

    Raises:
        AttributeError: If an invalid attention type is provided.
        AssertionError: If ``attn_type`` is ``ENCODER_ONLY`` and
            ``is_prompt`` is false.
    """
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_DECODER):
        # In both cases the query is the input (decoder) sequence, so its
        # maximum length depends on whether this is a prompt run.
        if is_prompt:
            max_query_len = attn_metadata.max_prefill_seq_len
        else:
            max_query_len = attn_metadata.max_decode_seq_len

        if attn_type == AttentionType.DECODER:
            # Decoder self-attention: query and key are the same sequence.
            return (attn_metadata.seq_start_loc, max_query_len,
                    attn_metadata.seq_start_loc, max_query_len)

        # Cross-attention: the key is the precomputed encoder sequence,
        # while the query is the input sequence.
        return (attn_metadata.seq_start_loc, max_query_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER:
        # For encoder attention both the query and the key are the same,
        # i.e. the encoder sequence.
        return (attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER_ONLY:
        assert is_prompt, "Should not have decode for encoder only model."
        return (attn_metadata.seq_start_loc,
                attn_metadata.max_prefill_seq_len,
                attn_metadata.seq_start_loc,
                attn_metadata.max_prefill_seq_len)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool:
    """
    Determine whether the given attention type is suitable for causal
    attention mechanisms.

    Args:
        attn_type (AttentionType): The type of attention being evaluated.

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention; `True` for every other value (including values that are
        not recognised attention types, since no validation is done here).
    """
    return attn_type not in (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
attn_type_int_val: int,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
# Convert integer attn_type to enum
try:
attn_type = AttentionType(attn_type_int_val)
except ValueError as err:
raise AttributeError(
f"Invalid attention type {str(attn_type_int_val)}") from err
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
...@@ -817,9 +690,9 @@ def unified_flash_attention( ...@@ -817,9 +690,9 @@ def unified_flash_attention(
# a. When the Attention Type is ENCODER. In this phase, we compute # a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache. # only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during # b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV # cross-attention computation in the decoding phase, where the
# cache is already populated with the cross-attention tensor. # KV cache is already populated with the cross-attention
# Thus, we skip cache updates during this time. # tensor. Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and ( if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None): value is not None):
if attn_type == AttentionType.ENCODER_DECODER: if attn_type == AttentionType.ENCODER_DECODER:
...@@ -831,7 +704,8 @@ def unified_flash_attention( ...@@ -831,7 +704,8 @@ def unified_flash_attention(
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory
# profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash( torch.ops._C_cache_ops.reshape_and_cache_flash(
key, key,
value, value,
...@@ -912,7 +786,8 @@ def unified_flash_attention( ...@@ -912,7 +786,8 @@ def unified_flash_attention(
# use only for actual varlen decoding # use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1: if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, ( assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1") "Only decoder-only models support max_decode_query_len > 1"
)
decode_output = flash_attn_varlen_func( decode_output = flash_attn_varlen_func(
q=decode_query, q=decode_query,
k=key_cache, k=key_cache,
...@@ -960,30 +835,87 @@ def unified_flash_attention( ...@@ -960,30 +835,87 @@ def unified_flash_attention(
output = torch.cat([prefill_output, decode_output], dim=0) output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
return output
def unified_flash_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """
    Placeholder counterpart of `unified_flash_attention`.

    Takes the same parameter list as the real implementation but performs
    no attention computation: only `query` is consulted, and only for its
    shape, dtype and device.

    Returns:
        torch.Tensor: An uninitialized tensor created with
        `torch.empty_like(query)`.
    """
    return torch.empty_like(query)
def _get_query_key_seq_metadata(
attn_metadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
"""
Returns sequence metadata for key and query based on the specified
attention type and whether input is a prompt.
This function computes the starting locations and maximum sequence lengths
for key and query sequences for different attention types.
Args:
attn_metadata: The attention metadata object
is_prompt (bool): A flag indicating if the input is a prompt
attn_type (AttentionType): The type of attention being used.
Returns:
tuple: A tuple containing four integers:
- Starting location for the query sequence.
- Maximum sequence length for the query sequence.
- Starting location for the key sequence.
- Maximum sequence length for the key sequence.
direct_register_custom_op( Raises:
op_name="unified_flash_attention", AttributeError: If an invalid attention type is provided.
op_func=unified_flash_attention, """
mutates_args=["kv_cache"], if attn_type == AttentionType.DECODER:
fake_impl=unified_flash_attention_fake, # Decoder self-attention
) # Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_start_loc, max_seq_len,
attn_metadata.seq_start_loc, max_seq_len)
elif attn_type == AttentionType.ENCODER_DECODER:
# This is cross attention between the where the key
# is the precomputed encoder attention and query
# is the input sequence.
# Choose query max length based on whether it is prompt
# or not.
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_start_loc, max_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len)
elif attn_type == AttentionType.ENCODER:
# For encoder attention both the query and the key are same i.e the
# encoder sequence.
return (attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."
return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: str) -> bool:
    """
    Determine whether the given attention type is suitable for causal
    attention mechanisms.

    Args:
        attn_type (str): The type of attention being evaluated, compared
            against the `AttentionType` constants.

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention; `True` for every other value (including strings that
        are not recognised attention types, since no validation is done
        here).
    """
    return attn_type not in (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
...@@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, make_tensor_with_pad)
get_kv_cache_torch_dtype, make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
...@@ -782,45 +781,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -782,45 +781,14 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashInferImpl") "FlashInferImpl")
return torch.ops.vllm.unified_flash_infer( num_heads: int = self.num_heads
query, head_size: int = self.head_size
key, num_kv_heads: int = self.num_kv_heads
value, kv_cache_dtype: str = self.kv_cache_dtype
self.num_heads, softmax_scale: float = self.scale
self.head_size, window_size = self.sliding_window
self.num_kv_heads, alibi_slopes = self.alibi_slopes
kv_cache, logits_soft_cap = self.logits_soft_cap
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashInferMetadata)
attn_metadata: FlashInferMetadata = current_metadata
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size) query = query.view(-1, num_heads, head_size)
...@@ -852,7 +820,8 @@ def unified_flash_infer( ...@@ -852,7 +820,8 @@ def unified_flash_infer(
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous() # Flashinfer requires query to be contiguous query = query.contiguous(
) # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
# QKV for prefill. # QKV for prefill.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
...@@ -899,9 +868,9 @@ def unified_flash_infer( ...@@ -899,9 +868,9 @@ def unified_flash_infer(
v_scale=v_scale, v_scale=v_scale,
window_left=window_left) window_left=window_left)
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
assert attn_metadata.decode_metadata is not None assert decode_meta is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None assert decode_meta.decode_wrapper is not None
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward( decode_output = decode_meta.decode_wrapper.forward(
decode_query, decode_query,
kv_cache, kv_cache,
sm_scale=softmax_scale, sm_scale=softmax_scale,
...@@ -926,30 +895,3 @@ def unified_flash_infer( ...@@ -926,30 +895,3 @@ def unified_flash_infer(
decode_output = decode_output.squeeze(1) decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0) output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
def unified_flash_infer_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """
    Placeholder counterpart of `unified_flash_infer`.

    Takes the same parameter list as the real implementation but performs
    no attention computation: only `query` is consulted, and only for its
    shape, dtype and device.

    Returns:
        torch.Tensor: An uninitialized, contiguous tensor with the same
        shape, dtype and device as `query`.
    """
    # `empty_like` may keep a non-contiguous layout from `query`, so the
    # result is made contiguous explicitly.
    return torch.empty_like(query).contiguous()
# Register `unified_flash_infer` under the op name "unified_flash_infer",
# with `unified_flash_infer_fake` supplied as its fake implementation and
# `kv_cache` declared as the argument the op mutates.
# NOTE(review): presumably this is what makes the op reachable as
# `torch.ops.vllm.unified_flash_infer` -- confirm against
# `vllm.utils.direct_register_custom_op`.
direct_register_custom_op(
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)
...@@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata: HPUAttentionMetadata, attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
......
...@@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
......
...@@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
......
...@@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
......
...@@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_lens( def get_seq_lens(
self, self,
attn_type: AttentionType, attn_type: str,
): ):
''' '''
Extract appropriate sequence lengths from attention metadata Extract appropriate sequence lengths from attention metadata
...@@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_attn_bias( def get_attn_bias(
self, self,
attn_type: AttentionType, attn_type: str,
) -> Optional[List[torch.Tensor]]: ) -> Optional[List[torch.Tensor]]:
''' '''
Extract appropriate attention bias from attention metadata Extract appropriate attention bias from attention metadata
...@@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def set_attn_bias( def set_attn_bias(
self, self,
attn_bias: List[torch.Tensor], attn_bias: List[torch.Tensor],
attn_type: AttentionType, attn_type: str,
) -> None: ) -> None:
''' '''
Update appropriate attention bias field of attention metadata, Update appropriate attention bias field of attention metadata,
...@@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_len_block_table_args( def get_seq_len_block_table_args(
self, self,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
''' '''
The particular choice of sequence-length- and block-table-related The particular choice of sequence-length- and block-table-related
...@@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: TorchSDPAMetadata, attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> None: ) -> None:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
......
...@@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata): ...@@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
def get_seq_len_block_table_args( def get_seq_len_block_table_args(
attn_metadata, attn_metadata,
is_prompt: bool, is_prompt: bool,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
''' '''
The particular choice of sequence-length- and block-table-related The particular choice of sequence-length- and block-table-related
...@@ -529,7 +529,7 @@ def get_seq_len_block_table_args( ...@@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
def get_num_prefill_decode_query_kv_tokens( def get_num_prefill_decode_query_kv_tokens(
attn_metadata, attn_metadata,
attn_type: AttentionType, attn_type: str,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
""" """
Calculate the number of prefill and decode tokens for query, key/value Calculate the number of prefill and decode tokens for query, key/value
......
...@@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
def _get_attn_bias( def _get_attn_bias(
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType, attn_type: str,
) -> Optional[AttentionBias]: ) -> Optional[AttentionBias]:
''' '''
Extract appropriate attention bias from attention metadata Extract appropriate attention bias from attention metadata
...@@ -314,7 +314,7 @@ def _get_attn_bias( ...@@ -314,7 +314,7 @@ def _get_attn_bias(
def _set_attn_bias( def _set_attn_bias(
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]], attn_bias: List[Optional[AttentionBias]],
attn_type: AttentionType, attn_type: str,
) -> None: ) -> None:
''' '''
Update appropriate attention bias field of attention metadata, Update appropriate attention bias field of attention metadata,
...@@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.
......
...@@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional ...@@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op
class Attention(nn.Module): class Attention(nn.Module):
...@@ -86,6 +91,18 @@ class Attention(nn.Module): ...@@ -86,6 +91,18 @@ class Attention(nn.Module):
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap) blocksparse_params, logits_soft_cap)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = envs.VLLM_USE_V1 or not (
current_platform.is_cuda_alike() or current_platform.is_cpu())
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
...@@ -93,9 +110,10 @@ class Attention(nn.Module): ...@@ -93,9 +110,10 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_direct_call:
return self.impl.forward(query, return self.impl.forward(query,
key, key,
value, value,
...@@ -104,6 +122,10 @@ class Attention(nn.Module): ...@@ -104,6 +122,10 @@ class Attention(nn.Module):
self._k_scale, self._k_scale,
self._v_scale, self._v_scale,
attn_type=attn_type) attn_type=attn_type)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
...@@ -112,3 +134,44 @@ class Attention(nn.Module): ...@@ -112,3 +134,44 @@ class Attention(nn.Module):
s += f", scale={self.impl.scale}" # type: ignore s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}" s += f", backend={self.impl.__class__.__name__}"
return s return s
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered under ``layer_name``.

    The layer object is looked up by name in the static part of the current
    forward context, and the attention metadata is taken from its dynamic
    part. The call is then forwarded to that layer's backend implementation
    together with the layer's key/value scales.

    Args:
        query: Query tensor passed through to the backend.
        key: Key tensor passed through to the backend.
        value: Value tensor passed through to the backend.
        kv_cache: KV cache tensor passed through to the backend.
        attn_type: Attention type, forwarded as the ``attn_type`` keyword.
        layer_name: Key of the attention layer in the static forward context.

    Returns:
        Whatever the backend's ``forward`` returns for these inputs.
    """
    ctx: ForwardContext = get_forward_context()
    # The static context maps layer names to attention layer objects.
    layer = ctx.static_forward_context[layer_name]
    metadata = ctx.dynamic_forward_context
    return layer.impl.forward(
        query,
        key,
        value,
        kv_cache,
        metadata,
        layer._k_scale,
        layer._v_scale,
        attn_type=attn_type,
    )
def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_attention``.

    Only ``query`` is used: the result is an uninitialized, contiguous tensor
    with the same shape and dtype as ``query``. All other arguments are
    accepted so the signature mirrors ``unified_attention`` and are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()
# Register `unified_attention` as a custom op under the name
# "unified_attention"; `Attention.forward` invokes it as
# `torch.ops.vllm.unified_attention` when it does not call the backend
# directly. `kv_cache` is declared as the only argument the op mutates, and
# `unified_attention_fake` is supplied as the fake implementation.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=["kv_cache"],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
...@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel): ...@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
backend: str = "" backend: str = ""
custom_ops: List[str] = Field(default_factory=list) custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [ splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_flash_attention", "vllm.unified_attention",
"vllm.unified_flash_infer",
"vllm.unified_v1_flash_attention", "vllm.unified_v1_flash_attention",
]) ])
...@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel): ...@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
@classmethod @classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig": def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config.""" """Parse the CLI value for the compilation config."""
...@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel): ...@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter() self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter() self.disabled_custom_ops = Counter()
self.static_forward_context = {}
def init_backend(self) -> Union[str, Callable]: def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:
......
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from dataclasses import dataclass
from typing import Any, Dict, Optional
_forward_context: Any = None from vllm.config import VllmConfig
def get_forward_context() -> Any: @dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
    """Return the forward context that is currently set.

    Fails with an assertion error when the module-level forward context is
    still ``None``, i.e. no forward context has been set.
    """
    current = _forward_context
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return current
@contextmanager @contextmanager
def set_forward_context(context: Any): def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc.""" can be attention metadata, etc."""
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context
_forward_context = context _forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
try: try:
yield yield
finally: finally:
......
...@@ -223,6 +223,7 @@ class ArcticAttention(nn.Module): ...@@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
layer_idx: Optional[int] = None, layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -274,7 +275,8 @@ class ArcticAttention(nn.Module): ...@@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module): ...@@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
...@@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module): ...@@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
self.self_attn = ArcticAttention(config, self.self_attn = ArcticAttention(config,
layer_idx, layer_idx,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = ArcticMoE( self.block_sparse_moe = ArcticMoE(
config, config,
layer_id=layer_idx, layer_id=layer_idx,
...@@ -380,8 +384,11 @@ class ArcticModel(nn.Module): ...@@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
org_num_embeddings=self.vocab_size) org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config, int( lambda prefix: ArcticDecoderLayer(config,
prefix.split(".")[-1]), cache_config, quant_config), int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module): ...@@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module): ...@@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module): ...@@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
...@@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module): ...@@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding, lambda prefix: BaiChuanDecoderLayer(config,
cache_config, quant_config), position_embedding,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module): ...@@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module): ...@@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor: attn_metadata: AttentionMetadata) -> torch.Tensor:
...@@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module): ...@@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module): ...@@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor: attn_metadata: AttentionMetadata) -> torch.Tensor:
...@@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module): ...@@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module): ...@@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module): ...@@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
config: BartConfig, config: BartConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
...@@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module): ...@@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
num_heads=config.encoder_attention_heads, num_heads=config.encoder_attention_heads,
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
...@@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module): ...@@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
config: BartConfig, config: BartConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
...@@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module): ...@@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
num_heads=config.decoder_attention_heads, num_heads=config.decoder_attention_heads,
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module): ...@@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
self.embed_dim, self.embed_dim,
config.decoder_attention_heads, config.decoder_attention_heads,
config=config, config=config,
prefix=f"{prefix}.encoder_attn",
) )
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -578,7 +591,8 @@ class BartEncoder(nn.Module): ...@@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None): embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.cache_config = cache_config self.cache_config = cache_config
...@@ -599,9 +613,13 @@ class BartEncoder(nn.Module): ...@@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
config.max_position_embeddings, config.max_position_embeddings,
embed_dim, embed_dim,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[BartEncoderLayer(config,cache_config,quant_config) \ BartEncoderLayer(config,
for _ in range(config.encoder_layers)]) cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])
self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layernorm_embedding = nn.LayerNorm(embed_dim)
...@@ -661,6 +679,7 @@ class BartDecoder(nn.Module): ...@@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None, embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.cache_config = cache_config self.cache_config = cache_config
...@@ -683,8 +702,9 @@ class BartDecoder(nn.Module): ...@@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \ [BartDecoderLayer(config,cache_config,quant_config,
for _ in range(config.decoder_layers)]) prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model) self.layernorm_embedding = nn.LayerNorm(config.d_model)
...@@ -759,10 +779,12 @@ class BartModel(nn.Module): ...@@ -759,10 +779,12 @@ class BartModel(nn.Module):
self.encoder = BartEncoder(config, self.encoder = BartEncoder(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config, self.decoder = BartDecoder(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
......
...@@ -78,6 +78,7 @@ class BloomAttention(nn.Module): ...@@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -116,7 +117,8 @@ class BloomAttention(nn.Module): ...@@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -168,14 +170,17 @@ class BloomBlock(nn.Module): ...@@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config, self.self_attention = BloomAttention(config,
quant_config) cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config) self.mlp = BloomMLP(config, quant_config)
...@@ -242,7 +247,8 @@ class BloomModel(nn.Module): ...@@ -242,7 +247,8 @@ class BloomModel(nn.Module):
# Transformer blocks # Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config), lambda prefix: BloomBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm
......
...@@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module): ...@@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module): ...@@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor, def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = ChameleonMLP( self.mlp = ChameleonMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = ChameleonMLP( self.mlp = ChameleonMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -855,7 +861,8 @@ class ChameleonModel(nn.Module): ...@@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: decoder_layer(config=config, lambda prefix: decoder_layer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment