Unverified Commit 4f95ffee authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Hardware][CPU] Cross-attention and Encoder-Decoder models support on CPU backend (#9089)

parent 8c6de96e
...@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" ...@@ -23,6 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test # Run basic model test
docker exec cpu-test bash -c " docker exec cpu-test bash -c "
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language \ pytest -v -s tests/models/decoder_only/language \
--ignore=tests/models/test_fp8.py \ --ignore=tests/models/test_fp8.py \
--ignore=tests/models/decoder_only/language/test_jamba.py \ --ignore=tests/models/decoder_only/language/test_jamba.py \
......
...@@ -4,29 +4,23 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`. ...@@ -4,29 +4,23 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
""" """
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
from vllm.utils import is_cpu import pytest
from transformers import AutoModelForSeq2SeqLM
if not is_cpu(): from vllm.sequence import SampleLogprobs
# CPU backend is not currently supported with encoder/decoder models
# skip test definitions entirely to avoid importing GPU kernel libs
# (xFormers, etc.)
import pytest from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
from transformers import AutoModelForSeq2SeqLM
from vllm.sequence import SampleLogprobs
from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
HfRunner, VllmRunner) HfRunner, VllmRunner)
from ....utils import multi_gpu_test from ....utils import multi_gpu_test
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
def vllm_to_hf_output(
def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
decoder_prompt_type: DecoderPromptType, decoder_prompt_type: DecoderPromptType,
): ):
"""Sanitize vllm output to be comparable with hf output.""" """Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output output_ids, output_str, out_logprobs = vllm_output
...@@ -36,7 +30,8 @@ if not is_cpu(): ...@@ -36,7 +30,8 @@ if not is_cpu():
return output_ids, hf_output_str, out_logprobs return output_ids, hf_output_str, out_logprobs
def run_test(
def run_test(
hf_runner: Type[HfRunner], hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner], vllm_runner: Type[VllmRunner],
prompts: List[ExplicitEncoderDecoderPrompt[str, str]], prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
...@@ -48,7 +43,7 @@ if not is_cpu(): ...@@ -48,7 +43,7 @@ if not is_cpu():
num_logprobs: int, num_logprobs: int,
tensor_parallel_size: int, tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None, distributed_executor_backend: Optional[str] = None,
) -> None: ) -> None:
''' '''
Test the vLLM BART model for a variety of encoder/decoder input prompts, Test the vLLM BART model for a variety of encoder/decoder input prompts,
by validating it against HuggingFace (HF) BART. by validating it against HuggingFace (HF) BART.
...@@ -131,8 +126,7 @@ if not is_cpu(): ...@@ -131,8 +126,7 @@ if not is_cpu():
# decoder-only unit tests expect), so when testing an encoder/decoder # decoder-only unit tests expect), so when testing an encoder/decoder
# model we must explicitly specify enforce_eager=True in the VllmRunner # model we must explicitly specify enforce_eager=True in the VllmRunner
# constructor. # constructor.
with vllm_runner( with vllm_runner(model,
model,
dtype=dtype, dtype=dtype,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
...@@ -154,16 +148,15 @@ if not is_cpu(): ...@@ -154,16 +148,15 @@ if not is_cpu():
with hf_runner(model, dtype=dtype, with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model: auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = ( hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts, prompts,
max_tokens, max_tokens,
num_logprobs, num_logprobs,
**hf_kwargs, **hf_kwargs,
)) ))
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE hf_skip_tokens = (1
else 0) if decoder_prompt_type == DecoderPromptType.NONE else 0)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
...@@ -176,14 +169,14 @@ if not is_cpu(): ...@@ -176,14 +169,14 @@ if not is_cpu():
num_outputs_0_skip_tokens=hf_skip_tokens, num_outputs_0_skip_tokens=hf_skip_tokens,
) )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
model, dtype, max_tokens, num_logprobs, def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
decoder_prompt_type) -> None: dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
run_test( run_test(
hf_runner, hf_runner,
...@@ -197,14 +190,15 @@ if not is_cpu(): ...@@ -197,14 +190,15 @@ if not is_cpu():
tensor_parallel_size=1, tensor_parallel_size=1,
) )
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) @pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
def test_models_distributed(hf_runner, vllm_runner,
example_encoder_decoder_prompts, example_encoder_decoder_prompts,
distributed_executor_backend, model, dtype, distributed_executor_backend, model, dtype,
max_tokens, num_logprobs, max_tokens, num_logprobs,
......
...@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
seq_lens: Optional[List[int]] seq_lens: Optional[List[int]]
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt # It is a list because it is needed to set per prompt
...@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API. # from xformer API.
# will not appear in the __repr__ and __init__ # will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None self.attn_bias: Optional[List[torch.Tensor]] = None
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property @property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
...@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
return self return self
def get_seq_lens(
self,
attn_type: AttentionType,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if attn_type == AttentionType.DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: AttentionType,
) -> Optional[List[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if attn_type == AttentionType.DECODER:
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: List[torch.Tensor],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if attn_type == AttentionType.DECODER:
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
self.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -171,84 +339,101 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert k_scale == 1.0 and v_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if (attn_type == AttentionType.ENCODER
raise NotImplementedError("Encoder self-attention and " and (not attn_metadata.is_all_encoder_attn_metadata_set)):
"encoder/decoder cross-attention " raise AttributeError("Encoder attention requires setting "
"are not implemented for " "encoder metadata attributes.")
"TorchSDPABackendImpl") elif (attn_type == AttentionType.ENCODER_DECODER
num_tokens, hidden_size = query.shape and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
else:
if kv_cache.numel() > 0: assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, updated_slot_mapping,
self.kv_cache_dtype, k_scale, self.kv_cache_dtype,
v_scale) k_scale, v_scale)
if attn_metadata.is_prompt: if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0 if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0): or prefill_meta.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads: output = self._run_sdpa_forward(query,
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key,
value = value.repeat_interleave(self.num_queries_per_kv, value,
dim=1) prefill_meta,
attn_type=attn_type)
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
else: else:
# prefix-enabled attention # prefix-enabled attention
raise RuntimeError( raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.") "Torch SDPA backend doesn't support prefix decoding.")
else: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode( output = PagedAttention.forward_decode(
query, query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, block_tables_arg,
attn_metadata.seq_lens_tensor, seq_lens_arg,
attn_metadata.max_decode_seq_len, max_seq_len_arg,
self.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
...@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -260,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
def _run_sdpa_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
output = torch.empty_like(query)
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sub_out = scaled_dot_product_attention(
query[None, :, start_q:end_q, :],
key[None, :, start_kv:end_kv, :],
value[None, :, start_kv:end_kv, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and not self.need_mask,
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
......
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalInputs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunner,
ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInputForCPU":
return cast(
EncoderDecoderModelInputForCPU,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class CPUEncoderDecoderModelRunner(CPUModelRunner):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str,
Any]) -> EncoderDecoderModelInputForCPU:
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU:
model_input = super().prepare_model_input(seq_group_metadata_list,
virtual_engine,
finished_requests_ids)
model_input = cast(EncoderDecoderModelInputForCPU, model_input)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input)
return dataclasses.replace(
model_input,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
)
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInputForCPU,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
# Since we are not supporting chunked prefill either the entire
# batch is prefill or it is decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
max_len_of_block_table = max(
len(block_table) for block_table in cross_block_tables)
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max_len_of_block_table,
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)
@torch.no_grad()
def execute_model(
self,
model_input: EncoderDecoderModelInputForCPU,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"encoder_input_ids":
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
"intermediate_tensors":
intermediate_tensors,
}
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
...@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, ...@@ -19,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
...@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -434,10 +434,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
@property @property
def model_is_mrope(self) -> bool: def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type. """Detect if the model has "mrope" rope_scaling type.
...@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -459,8 +455,8 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
tensor_dict: Dict[str, Any], tensor_dict: Dict[str, Any],
) -> ModelInputForCPU: ) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPU.from_broadcasted_tensor_dict( return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict, tensor_dict,
attn_backend=self.attn_backend, attn_backend=self.attn_backend,
) )
......
"""A CPU worker class.""" """A CPU worker class."""
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Type
import torch import torch
import torch.distributed import torch.distributed
...@@ -15,6 +15,7 @@ from vllm.logger import init_logger ...@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerInput)
...@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -163,7 +164,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else: else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank] self.local_omp_cpuid = omp_cpuids.split("|")[rank]
self.model_runner: CPUModelRunner = CPUModelRunner( ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
if self._is_encoder_decoder_model():
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunner = ModelRunnerClass(
model_config, model_config,
parallel_config, parallel_config,
scheduler_config, scheduler_config,
...@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -205,6 +209,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
raise RuntimeError("Profiler is not enabled.") raise RuntimeError("Profiler is not enabled.")
self.profiler.stop() self.profiler.stop()
def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
def init_device(self) -> None: def init_device(self) -> None:
if self.local_omp_cpuid != "all": if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment