"docker/vscode:/vscode.git/clone" did not exist on "21d47d0eb8fca8b4cdfea37f99a1c8f40b3c4782"
Unverified Commit eebad39f authored by youkaichao, committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent db100c5c
...@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, ...@@ -18,8 +18,10 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
...@@ -594,6 +596,7 @@ def _run_encoder_attention_test( ...@@ -594,6 +596,7 @@ def _run_encoder_attention_test(
encoder_test_params: PhaseTestParameters, encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder attention. Run encoder attention.
...@@ -623,7 +626,7 @@ def _run_encoder_attention_test( ...@@ -623,7 +626,7 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
...@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test( ...@@ -648,6 +651,7 @@ def _run_decoder_self_attention_test(
decoder_test_params: PhaseTestParameters, decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run decoder self-attention test. Run decoder self-attention test.
...@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test( ...@@ -677,7 +681,7 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
...@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -701,6 +705,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_test_params: Optional[PhaseTestParameters], cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
test_pt: TestPoint, test_pt: TestPoint,
vllm_config: VllmConfig,
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
Run encoder/decoder cross-attention test. Run encoder/decoder cross-attention test.
...@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -748,7 +753,7 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, vllm_config):
# In the test setup the shape of the query is # In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However # [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be # the attention backend expect the shape to be
...@@ -839,7 +844,9 @@ def test_encoder_only( ...@@ -839,7 +844,9 @@ def test_encoder_only(
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
test_rsrcs = _make_test_resources(test_pt) vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
# during prefill) # during prefill)
...@@ -863,7 +870,8 @@ def test_encoder_only( ...@@ -863,7 +870,8 @@ def test_encoder_only(
test_rsrcs.attn, test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt)) test_pt=test_pt,
vllm_config=vllm_config))
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
...@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn( ...@@ -960,7 +968,9 @@ def test_e2e_enc_dec_attn(
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
test_rsrcs = _make_test_resources(test_pt) vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
# during prefill) # during prefill)
...@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn( ...@@ -1011,7 +1021,8 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is encoder attention result correct? # - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
...@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn( ...@@ -1023,7 +1034,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs, test_rsrcs,
prephase_dec_test_params, prephase_dec_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill decoder self-attention correct? # - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params, assert_actual_matches_ideal(prephase_dec_test_params,
...@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn( ...@@ -1037,7 +1049,8 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params, prephase_dec_test_params,
prephase_cross_test_params, prephase_cross_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct? # - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params, assert_actual_matches_ideal(prephase_cross_test_params,
...@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn( ...@@ -1061,7 +1074,8 @@ def test_e2e_enc_dec_attn(
test_rsrcs, test_rsrcs,
decphase_dec_test_params, decphase_dec_test_params,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct? # - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params, assert_actual_matches_ideal(decphase_dec_test_params,
...@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn( ...@@ -1075,7 +1089,8 @@ def test_e2e_enc_dec_attn(
decphase_dec_test_params, decphase_dec_test_params,
None, None,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt) test_pt=test_pt,
vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct? # - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params, assert_actual_matches_ideal(decphase_cross_test_params,
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar) Tuple, Type, TypeVar)
...@@ -15,13 +14,19 @@ if TYPE_CHECKING: ...@@ -15,13 +14,19 @@ if TYPE_CHECKING:
ModelRunnerInputBuilderBase) ModelRunnerInputBuilderBase)
class AttentionType(Enum): class AttentionType:
DECODER = auto() # Decoder attention between previous layer Q/K/V """
ENCODER = auto( Attention type.
) # Encoder attention between previous layer Q/K/V for encoder-decoder Use string to be compatible with `torch.compile`.
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V """
ENCODER_DECODER = auto( # Decoder attention between previous layer Q/K/V
) # Attention between dec. Q and enc. K/V for encoder-decoder DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER = "encoder_decoder"
class AttentionBackend(ABC): class AttentionBackend(ABC):
...@@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -241,6 +246,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T, attn_metadata: T,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -354,7 +354,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
......
...@@ -16,10 +16,8 @@ from vllm.attention.backends.utils import ( ...@@ -16,10 +16,8 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, from vllm.utils import async_tensor_h2d, make_tensor_with_pad
make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -639,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -668,23 +666,174 @@ class FlashAttentionImpl(AttentionImpl):
"requires setting cross-attention " "requires setting cross-attention "
"metadata attributes.") "metadata attributes.")
output = torch.ops.vllm.unified_flash_attention( num_heads: int = self.num_heads
query, head_size: int = self.head_size
key, num_kv_heads: int = self.num_kv_heads
value, kv_cache_dtype: str = self.kv_cache_dtype
self.num_heads, softmax_scale: float = self.scale
self.head_size, window_size = self.sliding_window
self.num_kv_heads, alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
kv_cache, logits_soft_cap: Optional[float] = self.logits_soft_cap
self.kv_cache_dtype,
k_scale, num_tokens, hidden_size = query.shape
v_scale,
self.scale, # Reshape the query, key, and value tensors.
attn_type.value, query = query.view(-1, num_heads, head_size)
self.sliding_window, if (key is not None) and (value is not None):
self.alibi_slopes, key = key.view(-1, num_kv_heads, head_size)
self.logits_soft_cap, value = value.view(-1, num_kv_heads, head_size)
)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the
# KV cache is already populated with the cross-attention
# tensor. Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
return output return output
...@@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -692,7 +841,7 @@ class FlashAttentionImpl(AttentionImpl):
def _get_query_key_seq_metadata( def _get_query_key_seq_metadata(
attn_metadata, attn_metadata,
is_prompt: bool, is_prompt: bool,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
""" """
Returns sequence metadata for key and query based on the specified Returns sequence metadata for key and query based on the specified
...@@ -754,7 +903,7 @@ def _get_query_key_seq_metadata( ...@@ -754,7 +903,7 @@ def _get_query_key_seq_metadata(
raise AttributeError(f"Invalid attention type {str(attn_type)}") raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool: def _get_causal_option(attn_type: str) -> bool:
""" """
Determine whether the given attention type is suitable for causal Determine whether the given attention type is suitable for causal
attention mechanisms. attention mechanisms.
...@@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool: ...@@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool:
return not (attn_type == AttentionType.ENCODER return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER) or attn_type == AttentionType.ENCODER_DECODER)
def unified_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Run FlashAttention for the prefill and/or decode tokens of a batch.

    The attention metadata is not passed in; it is read from the current
    forward context and must be a ``FlashAttentionMetadata``.

    Args:
        query: 2-D tensor unpacked as ``(num_tokens, hidden_size)``; it is
            viewed as ``(-1, num_heads, head_size)`` internally.
        key: Key tensor viewed as ``(-1, num_kv_heads, head_size)``. May be
            ``None`` together with ``value`` (decode-phase cross-attention),
            in which case the KV cache is not updated.
        value: Value tensor, handled like ``key``.
        num_heads: Number of query heads.
        head_size: Size of each attention head.
        num_kv_heads: Number of key/value heads.
        kv_cache: ``kv_cache[0]`` is used as the key cache and
            ``kv_cache[1]`` as the value cache. An empty tensor
            (``numel() == 0``) means no cache is available.
        kv_cache_dtype: Forwarded to ``reshape_and_cache_flash``.
        k_scale: Forwarded to ``reshape_and_cache_flash``.
        v_scale: Forwarded to ``reshape_and_cache_flash``.
        softmax_scale: Softmax scaling passed to the attention kernels.
        attn_type_int_val: Integer value of an ``AttentionType`` member.
        window_size: Forwarded to the attention kernels.
        alibi_slopes: Forwarded to the attention kernels.
        logits_soft_cap: Forwarded to the attention kernels as ``softcap``.

    Returns:
        The attention output viewed as ``(n, hidden_size)``, where ``n`` is
        the number of decode query tokens, prefill query tokens, or all
        tokens, depending on which phases are present in the metadata.

    Raises:
        AttributeError: If ``attn_type_int_val`` is not a valid
            ``AttentionType`` value.
    """
    # Convert integer attn_type to enum
    try:
        attn_type = AttentionType(attn_type_int_val)
    except ValueError as err:
        raise AttributeError(
            f"Invalid attention type {str(attn_type_int_val)}") from err

    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, FlashAttentionMetadata)
    attn_metadata: FlashAttentionMetadata = current_metadata

    num_tokens, hidden_size = query.shape

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

    if kv_cache.numel() > 0:
        key_cache = kv_cache[0]
        value_cache = kv_cache[1]
        # We skip updating the KV cache under two conditions:
        #  a. When the Attention Type is ENCODER. In this phase, we compute
        #     only the encoder attention without updating the cache.
        #  b. When both Key and Value are None. This occurs during
        #     cross-attention computation in the decoding phase, where the KV
        #     cache is already populated with the cross-attention tensor.
        #     Thus, we skip cache updates during this time.
        if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[0],
                kv_cache[1],
                updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    # Prefill tokens come first in the batch; the remainder are decode tokens.
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    prefill_output: Optional[torch.Tensor] = None
    decode_output: Optional[torch.Tensor] = None

    if prefill_meta := attn_metadata.prefill_metadata:
        # Prompt run.
        if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
                or prefill_meta.block_tables.numel() == 0):
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                _get_query_key_seq_metadata(prefill_meta, True, attn_type)

            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

            prefill_output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=q_seq_start_loc,
                cu_seqlens_k=k_seq_start_loc,
                max_seqlen_q=q_seq_len,
                max_seqlen_k=k_seq_len,
                softmax_scale=softmax_scale,
                causal=_get_causal_option(attn_type),
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
            )
        else:
            # prefix-enabled attention
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support prefix caching")
            assert prefill_meta.seq_lens is not None
            max_seq_len = max(prefill_meta.seq_lens)
            prefill_output = flash_attn_varlen_func(  # noqa
                q=query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_k=max_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                block_table=prefill_meta.block_tables,
                softcap=logits_soft_cap,
            )

    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
        # Use flash_attn_varlen_func kernel for speculative decoding
        # because different queries might have different lengths.
        assert decode_meta.max_decode_query_len is not None
        # use only for actual varlen decoding
        if decode_meta.max_decode_query_len > 1:
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support max_decode_query_len > 1")
            decode_output = flash_attn_varlen_func(
                q=decode_query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_decode_query_len,
                cu_seqlens_k=decode_meta.seq_start_loc,
                max_seqlen_k=decode_meta.max_decode_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
                block_table=decode_meta.block_tables,
            )
        else:
            # Use flash_attn_with_kvcache for normal decoding.
            (
                seq_lens_arg,
                _,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
            # The kernel takes a per-sequence length axis; add it for the
            # call and drop it again from the result.
            decode_output = flash_attn_with_kvcache(
                q=decode_query.unsqueeze(1),
                k_cache=key_cache,
                v_cache=value_cache,
                block_table=block_tables_arg,
                cache_seqlens=seq_lens_arg,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
            ).squeeze(1)

    if prefill_output is None:
        assert decode_output is not None
        return decode_output.view(num_decode_query_tokens, hidden_size)
    if decode_output is None:
        assert prefill_output is not None
        return prefill_output.view(num_prefill_query_tokens, hidden_size)

    assert decode_meta is not None
    # Both outputs are already (tokens, num_heads, head_size) here: the
    # kvcache path squeezed its length-1 sequence axis above. A further
    # squeeze(1) would remove the head axis when num_heads == 1 and make the
    # concatenation below fail, so none is applied.
    output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)
def unified_flash_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Return an uninitialized tensor shaped like ``query``.

    The parameter list mirrors ``unified_flash_attention``; every argument
    other than ``query`` is accepted and ignored. No attention is computed
    and no argument is modified.
    """
    # Only the metadata of `query` matters; the contents are left
    # uninitialized.
    placeholder = torch.empty_like(query)
    return placeholder
# Register `unified_flash_attention` under the op name
# "unified_flash_attention". `kv_cache` is declared as mutated because the
# op writes new keys/values into it through `reshape_and_cache_flash`, and
# `unified_flash_attention_fake` is supplied as the op's fake implementation.
# NOTE(review): `direct_register_custom_op` comes from `vllm.utils`; its exact
# registration semantics are not visible in this file — confirm there.
direct_register_custom_op(
    op_name="unified_flash_attention",
    op_func=unified_flash_attention,
    mutates_args=["kv_cache"],
    fake_impl=unified_flash_attention_fake,
)
...@@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -30,9 +30,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, make_tensor_with_pad)
get_kv_cache_torch_dtype, make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -774,7 +773,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
...@@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl): ...@@ -782,174 +781,117 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashInferImpl") "FlashInferImpl")
return torch.ops.vllm.unified_flash_infer( num_heads: int = self.num_heads
query, head_size: int = self.head_size
key, num_kv_heads: int = self.num_kv_heads
value, kv_cache_dtype: str = self.kv_cache_dtype
self.num_heads, softmax_scale: float = self.scale
self.head_size, window_size = self.sliding_window
self.num_kv_heads, alibi_slopes = self.alibi_slopes
kv_cache, logits_soft_cap = self.logits_soft_cap
self.kv_cache_dtype,
k_scale, num_tokens, hidden_size = query.shape
v_scale, query = query.view(-1, num_heads, head_size)
self.scale, key = key.view(-1, num_kv_heads, head_size)
self.sliding_window, value = value.view(-1, num_kv_heads, head_size)
self.alibi_slopes,
self.logits_soft_cap, if kv_cache.numel() > 0:
) # Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
def unified_flash_infer( value,
query: torch.Tensor, kv_cache[:, 0],
key: torch.Tensor, kv_cache[:, 1],
value: torch.Tensor, attn_metadata.slot_mapping.flatten(),
num_heads: int, kv_cache_dtype,
head_size: int, k_scale,
num_kv_heads: int, v_scale,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashInferMetadata)
attn_metadata: FlashInferMetadata = current_metadata
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous() # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
) )
else: # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
assert prefill_meta is not None # to process the cache when the kv_cache_dtype is fp8
assert prefill_meta.prefill_wrapper is not None if kv_cache_dtype.startswith("fp8"):
prefill_output = prefill_meta.prefill_wrapper.forward( torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
query, kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous(
) # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
decode_query,
kv_cache, kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap, logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale, k_scale=k_scale,
v_scale=v_scale, v_scale=v_scale,
window_left=window_left) window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert attn_metadata.decode_metadata is not None if prefill_output is None and decode_output is not None:
assert attn_metadata.decode_metadata.decode_wrapper is not None # Decode only batch.
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward( output, num_tokens = decode_output, num_decode_tokens
decode_query, elif decode_output is None and prefill_output is not None:
kv_cache, # Prefill only batch.
sm_scale=softmax_scale, output, num_tokens = prefill_output, num_prefill_tokens
logits_soft_cap=logits_soft_cap, else:
k_scale=k_scale, # Chunked prefill batch does not work with speculative decoding in
v_scale=v_scale, # FlashInfer backend, so the query length for decode should be 1.
window_left=window_left) assert prefill_output is not None
assert decode_output is not None
if prefill_output is None and decode_output is not None: assert decode_meta is not None
# Decode only batch. assert decode_meta.decode_query_len == 1
output, num_tokens = decode_output, num_decode_tokens decode_output = decode_output.squeeze(1)
elif decode_output is None and prefill_output is not None: output = torch.cat([prefill_output, decode_output], dim=0)
# Prefill only batch. return output.view(num_tokens, hidden_size)
output, num_tokens = prefill_output, num_prefill_tokens
else:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None
assert decode_output is not None
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
def unified_flash_infer_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_flash_infer``.

    Takes the same arguments as ``unified_flash_infer`` but performs no
    attention computation: only ``query`` is inspected, and the result is
    an uninitialized tensor with ``query``'s shape, dtype and device.
    It is handed to ``direct_register_custom_op`` as ``fake_impl``
    (presumably so the op's output metadata can be derived without
    running the real kernel).

    Returns:
        A contiguous, uninitialized tensor shaped like ``query``.
    """
    # Ask for the contiguous layout up front rather than calling
    # `.contiguous()` afterwards: for a dense but non-contiguous `query`
    # the latter allocates twice and copies uninitialized memory.
    return torch.empty_like(query, memory_format=torch.contiguous_format)
# Register `unified_flash_infer` as a custom op at import time; it is the
# op that `FlashInferImpl.forward` calls as
# `torch.ops.vllm.unified_flash_infer`.
# `kv_cache` is listed in `mutates_args` because the real implementation
# writes the new key/value entries into it (`ops.reshape_and_cache_flash`).
# `unified_flash_infer_fake` is supplied as the fake implementation
# (presumably used to obtain output metadata without running the kernel).
direct_register_custom_op(
op_name="unified_flash_infer",
op_func=unified_flash_infer,
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)
...@@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -140,7 +140,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata: HPUAttentionMetadata, attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
......
...@@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -172,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
......
...@@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -150,7 +150,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
......
...@@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -414,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
......
...@@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -141,7 +141,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_lens( def get_seq_lens(
self, self,
attn_type: AttentionType, attn_type: str,
): ):
''' '''
Extract appropriate sequence lengths from attention metadata Extract appropriate sequence lengths from attention metadata
...@@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -174,7 +174,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_attn_bias( def get_attn_bias(
self, self,
attn_type: AttentionType, attn_type: str,
) -> Optional[List[torch.Tensor]]: ) -> Optional[List[torch.Tensor]]:
''' '''
Extract appropriate attention bias from attention metadata Extract appropriate attention bias from attention metadata
...@@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -203,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def set_attn_bias( def set_attn_bias(
self, self,
attn_bias: List[torch.Tensor], attn_bias: List[torch.Tensor],
attn_type: AttentionType, attn_type: str,
) -> None: ) -> None:
''' '''
Update appropriate attention bias field of attention metadata, Update appropriate attention bias field of attention metadata,
...@@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -229,7 +229,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
def get_seq_len_block_table_args( def get_seq_len_block_table_args(
self, self,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
''' '''
The particular choice of sequence-length- and block-table-related The particular choice of sequence-length- and block-table-related
...@@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -426,7 +426,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -574,7 +574,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: TorchSDPAMetadata, attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> None: ) -> None:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
......
...@@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata): ...@@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata):
def get_seq_len_block_table_args( def get_seq_len_block_table_args(
attn_metadata, attn_metadata,
is_prompt: bool, is_prompt: bool,
attn_type: AttentionType, attn_type: str,
) -> tuple: ) -> tuple:
''' '''
The particular choice of sequence-length- and block-table-related The particular choice of sequence-length- and block-table-related
...@@ -529,7 +529,7 @@ def get_seq_len_block_table_args( ...@@ -529,7 +529,7 @@ def get_seq_len_block_table_args(
def get_num_prefill_decode_query_kv_tokens( def get_num_prefill_decode_query_kv_tokens(
attn_metadata, attn_metadata,
attn_type: AttentionType, attn_type: str,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
""" """
Calculate the number of prefill and decode tokens for query, key/value Calculate the number of prefill and decode tokens for query, key/value
......
...@@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -284,7 +284,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
def _get_attn_bias( def _get_attn_bias(
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType, attn_type: str,
) -> Optional[AttentionBias]: ) -> Optional[AttentionBias]:
''' '''
Extract appropriate attention bias from attention metadata Extract appropriate attention bias from attention metadata
...@@ -314,7 +314,7 @@ def _get_attn_bias( ...@@ -314,7 +314,7 @@ def _get_attn_bias(
def _set_attn_bias( def _set_attn_bias(
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]], attn_bias: List[Optional[AttentionBias]],
attn_type: AttentionType, attn_type: str,
) -> None: ) -> None:
''' '''
Update appropriate attention bias field of attention metadata, Update appropriate attention bias field of attention metadata,
...@@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -416,7 +416,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -617,7 +617,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.
......
...@@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional ...@@ -4,12 +4,17 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op
class Attention(nn.Module): class Attention(nn.Module):
...@@ -86,6 +91,18 @@ class Attention(nn.Module): ...@@ -86,6 +91,18 @@ class Attention(nn.Module):
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap) blocksparse_params, logits_soft_cap)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = envs.VLLM_USE_V1 or not (
current_platform.is_cuda_alike() or current_platform.is_cpu())
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
...@@ -93,17 +110,22 @@ class Attention(nn.Module): ...@@ -93,17 +110,22 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, if self.use_direct_call:
key, return self.impl.forward(query,
value, key,
kv_cache, value,
attn_metadata, kv_cache,
self._k_scale, attn_metadata,
self._v_scale, self._k_scale,
attn_type=attn_type) self._v_scale,
attn_type=attn_type)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
self.layer_name)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
...@@ -112,3 +134,44 @@ class Attention(nn.Module): ...@@ -112,3 +134,44 @@ class Attention(nn.Module):
s += f", scale={self.impl.scale}" # type: ignore s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}" s += f", backend={self.impl.__class__.__name__}"
return s return s
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered under ``layer_name``.

    The attention metadata and the layer object are not passed as
    arguments; both are read from the current forward context:
    ``dynamic_forward_context`` supplies the attention metadata and
    ``static_forward_context[layer_name]`` supplies the layer whose
    ``impl.forward`` is invoked with that layer's ``_k_scale`` and
    ``_v_scale``.

    Args:
        query: Query tensor forwarded to the backend implementation.
        key: Key tensor forwarded to the backend implementation.
        value: Value tensor forwarded to the backend implementation.
        kv_cache: KV cache tensor forwarded to the backend implementation.
        attn_type: Attention type, forwarded as the ``attn_type`` keyword.
        layer_name: Key used to look up the layer in the static context.

    Returns:
        Whatever the selected layer's ``impl.forward`` returns.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.dynamic_forward_context
    attn_layer = ctx.static_forward_context[layer_name]
    backend_impl = attn_layer.impl
    return backend_impl.forward(
        query,
        key,
        value,
        kv_cache,
        metadata,
        attn_layer._k_scale,
        attn_layer._v_scale,
        attn_type=attn_type,
    )
def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_attention``.

    Takes the same arguments as ``unified_attention`` but runs no
    attention: only ``query`` is inspected, and the result is an
    uninitialized tensor with ``query``'s shape, dtype and device.
    It is handed to ``direct_register_custom_op`` as ``fake_impl``
    (presumably so the op's output metadata can be derived without
    running a real attention backend).

    Returns:
        A contiguous, uninitialized tensor shaped like ``query``.
    """
    # Ask for the contiguous layout up front rather than calling
    # `.contiguous()` afterwards: for a dense but non-contiguous `query`
    # the latter allocates twice and copies uninitialized memory.
    return torch.empty_like(query, memory_format=torch.contiguous_format)
# Register `unified_attention` as a custom op at import time; it is the op
# that `Attention.forward` calls as `torch.ops.vllm.unified_attention` when
# `use_direct_call` is false.
# `kv_cache` is declared as a mutated argument, `unified_attention_fake` is
# supplied as the fake implementation, and the dispatch key is taken from
# the current platform.
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
...@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel): ...@@ -2135,8 +2135,7 @@ class CompilationConfig(BaseModel):
backend: str = "" backend: str = ""
custom_ops: List[str] = Field(default_factory=list) custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [ splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_flash_attention", "vllm.unified_attention",
"vllm.unified_flash_infer",
"vllm.unified_v1_flash_attention", "vllm.unified_v1_flash_attention",
]) ])
...@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel): ...@@ -2197,6 +2196,11 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
@classmethod @classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig": def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config.""" """Parse the CLI value for the compilation config."""
...@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel): ...@@ -2228,6 +2232,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter() self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter() self.disabled_custom_ops = Counter()
self.static_forward_context = {}
def init_backend(self) -> Union[str, Callable]: def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:
......
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from dataclasses import dataclass
from typing import Any, Dict, Optional
_forward_context: Any = None from vllm.config import VllmConfig
def get_forward_context() -> Any: @dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
"""Get the current forward context.""" """Get the current forward context."""
assert _forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context.")
return _forward_context return _forward_context
@contextmanager @contextmanager
def set_forward_context(context: Any): def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc.""" can be attention metadata, etc."""
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context
_forward_context = context _forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
try: try:
yield yield
finally: finally:
......
...@@ -223,6 +223,7 @@ class ArcticAttention(nn.Module): ...@@ -223,6 +223,7 @@ class ArcticAttention(nn.Module):
layer_idx: Optional[int] = None, layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -274,7 +275,8 @@ class ArcticAttention(nn.Module): ...@@ -274,7 +275,8 @@ class ArcticAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module): ...@@ -299,6 +301,7 @@ class ArcticDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
...@@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module): ...@@ -308,7 +311,8 @@ class ArcticDecoderLayer(nn.Module):
self.self_attn = ArcticAttention(config, self.self_attn = ArcticAttention(config,
layer_idx, layer_idx,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = ArcticMoE( self.block_sparse_moe = ArcticMoE(
config, config,
layer_id=layer_idx, layer_id=layer_idx,
...@@ -380,8 +384,11 @@ class ArcticModel(nn.Module): ...@@ -380,8 +384,11 @@ class ArcticModel(nn.Module):
org_num_embeddings=self.vocab_size) org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config, int( lambda prefix: ArcticDecoderLayer(config,
prefix.split(".")[-1]), cache_config, quant_config), int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module): ...@@ -116,6 +116,7 @@ class BaiChuanAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module): ...@@ -158,7 +159,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module): ...@@ -171,7 +173,8 @@ class BaiChuanAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -195,7 +198,8 @@ class BaiChuanDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
...@@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -209,6 +213,7 @@ class BaiChuanDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module): ...@@ -275,8 +280,11 @@ class BaiChuanModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding, lambda prefix: BaiChuanDecoderLayer(config,
cache_config, quant_config), position_embedding,
cache_config,
quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module): ...@@ -126,6 +126,7 @@ class BartEncoderAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module): ...@@ -178,7 +179,8 @@ class BartEncoderAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor: attn_metadata: AttentionMetadata) -> torch.Tensor:
...@@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module): ...@@ -208,6 +210,7 @@ class BartDecoderSelfAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module): ...@@ -260,7 +263,8 @@ class BartDecoderSelfAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor: attn_metadata: AttentionMetadata) -> torch.Tensor:
...@@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module): ...@@ -290,6 +294,7 @@ class BartCrossAttention(nn.Module):
config: Optional[BartConfig] = None, config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module): ...@@ -342,7 +347,8 @@ class BartCrossAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module): ...@@ -384,6 +390,7 @@ class BartEncoderLayer(nn.Module):
config: BartConfig, config: BartConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
...@@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module): ...@@ -393,7 +400,9 @@ class BartEncoderLayer(nn.Module):
num_heads=config.encoder_attention_heads, num_heads=config.encoder_attention_heads,
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
...@@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module): ...@@ -464,6 +473,7 @@ class BartDecoderLayer(nn.Module):
config: BartConfig, config: BartConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
...@@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module): ...@@ -473,7 +483,9 @@ class BartDecoderLayer(nn.Module):
num_heads=config.decoder_attention_heads, num_heads=config.decoder_attention_heads,
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module): ...@@ -486,6 +498,7 @@ class BartDecoderLayer(nn.Module):
self.embed_dim, self.embed_dim,
config.decoder_attention_heads, config.decoder_attention_heads,
config=config, config=config,
prefix=f"{prefix}.encoder_attn",
) )
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -578,7 +591,8 @@ class BartEncoder(nn.Module): ...@@ -578,7 +591,8 @@ class BartEncoder(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None): embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.cache_config = cache_config self.cache_config = cache_config
...@@ -599,9 +613,13 @@ class BartEncoder(nn.Module): ...@@ -599,9 +613,13 @@ class BartEncoder(nn.Module):
config.max_position_embeddings, config.max_position_embeddings,
embed_dim, embed_dim,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[BartEncoderLayer(config,cache_config,quant_config) \ BartEncoderLayer(config,
for _ in range(config.encoder_layers)]) cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])
self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layernorm_embedding = nn.LayerNorm(embed_dim)
...@@ -661,6 +679,7 @@ class BartDecoder(nn.Module): ...@@ -661,6 +679,7 @@ class BartDecoder(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None, embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.cache_config = cache_config self.cache_config = cache_config
...@@ -683,8 +702,9 @@ class BartDecoder(nn.Module): ...@@ -683,8 +702,9 @@ class BartDecoder(nn.Module):
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \ [BartDecoderLayer(config,cache_config,quant_config,
for _ in range(config.decoder_layers)]) prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model) self.layernorm_embedding = nn.LayerNorm(config.d_model)
...@@ -759,10 +779,12 @@ class BartModel(nn.Module): ...@@ -759,10 +779,12 @@ class BartModel(nn.Module):
self.encoder = BartEncoder(config, self.encoder = BartEncoder(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = BartDecoder(config, self.decoder = BartDecoder(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
......
...@@ -78,6 +78,7 @@ class BloomAttention(nn.Module): ...@@ -78,6 +78,7 @@ class BloomAttention(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -116,7 +117,8 @@ class BloomAttention(nn.Module): ...@@ -116,7 +117,8 @@ class BloomAttention(nn.Module):
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -168,14 +170,17 @@ class BloomBlock(nn.Module): ...@@ -168,14 +170,17 @@ class BloomBlock(nn.Module):
config: BloomConfig, config: BloomConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, cache_config, self.self_attention = BloomAttention(config,
quant_config) cache_config,
quant_config,
prefix=f"{prefix}.self_attention")
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, quant_config) self.mlp = BloomMLP(config, quant_config)
...@@ -242,7 +247,8 @@ class BloomModel(nn.Module): ...@@ -242,7 +247,8 @@ class BloomModel(nn.Module):
# Transformer blocks # Transformer blocks
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config), lambda prefix: BloomBlock(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm
......
...@@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module): ...@@ -223,6 +223,7 @@ class ChameleonAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module): ...@@ -276,7 +277,8 @@ class ChameleonAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor, def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -313,6 +315,7 @@ class ChameleonDecoderLayer(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -336,6 +339,7 @@ class ChameleonDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = ChameleonMLP( self.mlp = ChameleonMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -386,6 +390,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
config: ChameleonConfig, config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -409,6 +414,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = ChameleonMLP( self.mlp = ChameleonMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -855,7 +861,8 @@ class ChameleonModel(nn.Module): ...@@ -855,7 +861,8 @@ class ChameleonModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: decoder_layer(config=config, lambda prefix: decoder_layer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config), quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment