Unverified Commit fd95e026 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Core] Subclass ModelRunner to support cross-attention & encoder sequences...


[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: default avatarAndrew Feldman <afeld2012@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
parent 660470e5
......@@ -148,8 +148,9 @@ steps:
- python3 cpu_offload.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
- python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- label: Models Test # 1hr10min
source_file_dependencies:
......@@ -289,6 +290,7 @@ steps:
commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
- pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
......
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
'''
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists
dtype = "float"
# Create a BART encoder/decoder model instance
llm = LLM(
model="facebook/bart-large-cnn",
dtype=dtype,
)
# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()
# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
prompt="The capital of France is"))
# - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt);
# decoder input prompt is assumed to be None
single_text_prompt_raw = text_prompt_raw # Pass a string directly
single_text_prompt = text_prompt # Pass a TextPrompt
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt
# - Pass explicit encoder and decoder input prompts within one data structure.
# Encoder and decoder prompts can both independently be text or tokens, with
# no requirement that they be the same prompt type. Some example prompt-type
# combinations are shown below, note that these are not exhaustive.
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt string directly, &
# pass decoder prompt tokens
encoder_prompt=single_text_prompt_raw,
decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
# Pass TextPrompt to encoder, and
# pass decoder prompt string directly
encoder_prompt=single_text_prompt,
decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt tokens directly, and
# pass TextPrompt to decoder
encoder_prompt=single_tokens_prompt,
decoder_prompt=single_text_prompt,
)
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
['An encoder prompt', 'Another encoder prompt'],
['A decoder prompt', 'Another decoder prompt'])
# - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM.
prompts = [
single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list
print(prompts)
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
max_tokens=20,
)
# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Encoder prompt: {encoder_prompt!r}, "
f"Decoder prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
......@@ -10,9 +10,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding, BatchFeature)
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
BatchFeature)
from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
......@@ -21,9 +23,11 @@ from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu)
is_cpu, to_enc_dec_tuple_list,
zip_enc_dec_prompt_lists)
logger = init_logger(__name__)
......@@ -120,6 +124,40 @@ def example_prompts() -> List[str]:
return prompts
@pytest.fixture
def example_encoder_decoder_prompts() \
-> Dict[DecoderPromptType,
Tuple[List[str], List[Optional[str]]]]:
'''
Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt,
decoder prompt) tuple.
Returns:
* Encoder prompt list
* Decoder prompt list (reverse of encoder prompt list)
'''
encoder_prompts = []
for filename in _TEST_PROMPTS:
encoder_prompts += _read_prompts(filename)
custom_decoder_prompts = encoder_prompts[::-1]
empty_str_decoder_prompts = [""] * len(encoder_prompts)
none_decoder_prompts = [None] * len(encoder_prompts)
# NONE decoder prompt type
return {
DecoderPromptType.NONE:
zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
DecoderPromptType.EMPTY_STR:
zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
DecoderPromptType.CUSTOM:
zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
}
@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
......@@ -152,6 +190,7 @@ class HfRunner:
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
......@@ -168,6 +207,8 @@ class HfRunner:
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
elif is_encoder_decoder_model:
auto_cls = AutoModelForSeq2SeqLM
else:
auto_cls = AutoModelForCausalLM
......@@ -314,6 +355,44 @@ class HfRunner:
all_logprobs.append(seq_logprobs)
return all_logprobs
def _hidden_states_to_logprobs(
self,
hidden_states,
num_logprobs,
) -> Tuple[List[Dict[int, float]], int]:
seq_logprobs: List[torch.Tensor] = []
output_len = len(hidden_states)
for _, hidden_state in enumerate(hidden_states):
last_hidden_states = hidden_state[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if getattr(self.model.get_output_embeddings(), "bias",
None) is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)
tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()
seq_logprobs_lst.append(tok_logprobs_dct)
return (
seq_logprobs_lst,
output_len,
)
def generate_greedy_logprobs_limit(
self,
prompts: List[str],
......@@ -346,37 +425,66 @@ class HfRunner:
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
for _, hidden_states in enumerate(output.hidden_states):
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if getattr(self.model.get_output_embeddings(), "bias",
None) is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(output.hidden_states,
num_logprobs)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
seq_logprobs_lst.append(tok_logprobs_dct)
def generate_encoder_decoder_greedy_logprobs_limit(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for (encoder_prompt,
decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
encoder_input_ids = self.wrap_device(
self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
decoder_input_ids = (
None if decoder_prompt is None else self.wrap_device(
self.tokenizer(decoder_prompt,
return_tensors="pt").input_ids))
output = self.model.generate(
encoder_input_ids,
decoder_input_ids=decoder_input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
......@@ -416,7 +524,7 @@ class VllmRunner:
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: bool = False,
enforce_eager: Optional[bool] = False,
**kwargs,
) -> None:
self.model = LLM(
......@@ -465,6 +573,19 @@ class VllmRunner:
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
def _final_steps_generate_w_logprobs(
self,
req_outputs: List[RequestOutput],
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
def generate_w_logprobs(
self,
prompts: List[str],
......@@ -483,14 +604,21 @@ class VllmRunner:
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
return self._final_steps_generate_w_logprobs(req_outputs)
def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
'''
Logprobs generation for vLLM encoder/decoder models
'''
assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
def generate_greedy(
self,
......@@ -523,6 +651,26 @@ class VllmRunner:
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs)
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
outputs = self.generate_encoder_decoder_w_logprobs(
encoder_decoder_prompts, greedy_logprobs_params)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def generate_beam_search(
self,
prompts: List[str],
......
......@@ -9,33 +9,11 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
from vllm.sequence import SequenceGroup, SequenceStatus
from .utils import create_dummy_prompt
def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(out, token_id: int):
seq_groups = get_sequence_groups(out)
for seq_group in seq_groups:
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
from .utils import (append_new_token, append_new_token_seq_group,
create_dummy_prompt, get_sequence_groups,
schedule_and_update_computed_tokens)
def test_scheduler_add_seq_group():
......
from typing import List
import pytest # noqa
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.sequence import SequenceGroup
from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
get_sequence_groups, schedule_and_update_computed_tokens)
def test_scheduler_schedule_simple_encoder_decoder():
'''
Test basic scheduler functionality in the context
of an encoder/decoder model. Focus on testing
enc/dec-specific functionality sense tests already
exist for decoder-only functionality
Test behavior:
* Construct Scheduler
* Construct dummy encoder/decoder sequence groups
* Add dummy seq groups to scheduler backlog
* Schedule the next seq group & validate:
* Cross-attn block tables
* Updated states of seq groups
* Number of batched tokens
* Number of blocks to copy/swap-in/swap-out
* Number of scheduled seq groups
* Repeat for both prefill- and decode-phase
* Abort scheduled seq groups
* Assert that aborted seq groups no longer appear in
cross-attention block table
'''
block_size = 4
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
req_id_list = []
for i in range(num_seq_group):
req_id = str(i)
req_id_list.append(req_id)
_, _, seq_group = create_dummy_prompt_encoder_decoder(
req_id, block_size, block_size, block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
# Schedule seq groups prefill.
num_tokens = block_size * num_seq_group
seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
# - Verify that sequence group cross-attention block tables are
# registered with the block manager
assert all([(req_id in scheduler.block_manager.cross_block_tables)
for req_id in req_id_list])
# - Validate sequence-group status
assert set(get_sequence_groups(out)) == set(running)
# - Validate number of batched tokens
assert out.num_batched_tokens == num_tokens
# - Validate there are no remaining blocks to swap
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
# - Validate all seq groups were scheduled
assert len(seq_group_meta_list) == num_seq_group
append_new_token(out, 1)
# Schedule seq groups decode.
seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
# - Verify that sequence group metadata includes encoder attention
# and cross-attention metadata
assert all([
not ((seq_group_meta.encoder_seq_data is None) or
(seq_group_meta.cross_block_table is None))
for seq_group_meta in seq_group_meta_list
])
# - Validate sequence-group status
assert set(get_sequence_groups(out)) == set(running)
# - Validate there is one batched token per seq group
assert out.num_batched_tokens == num_seq_group
# - Validate there are no remaining blocks to swap
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
# - Validate that all seq groups were scheduled
assert len(seq_group_meta_list) == num_seq_group
append_new_token(out, 1)
# Abort sequences
for req_id in req_id_list:
scheduler.abort_seq_group(req_id)
# - Verify that sequence group cross-attention block tables are
# NO LONGER registered with the block manager
assert req_id not in scheduler.block_manager.cross_block_tables
......@@ -53,27 +53,30 @@ def create_dummy_prompt_encoder_decoder(
block_size = decoder_prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
# and prompt "0 ... block_size". Note that the prompt string
# doesn't actually match the tokens
decoder_prompt_tokens = list(range(decoder_prompt_length))
decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
decoder_prompt = Sequence(int(request_id),
inputs={
inputs = {
"prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens,
"encoder_prompt": encoder_prompt_str,
"encoder_prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
}
decoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=True)
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
encoder_prompt = Sequence(int(request_id),
inputs={
"prompt": encoder_prompt_str,
"prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
inputs=inputs,
block_size=block_size,
from_decoder_prompt=False)
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(
......@@ -139,17 +142,21 @@ def create_seq_group_encoder_decoder(
prompt_token_ids = [0] * seq_prompt_len
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={
inputs = {
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"encoder_prompt": "",
"encoder_prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
}
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
# Construct decoder input sequences
seq = Sequence(seq_id=seq_id_start + seq_id_offset,
inputs=inputs,
block_size=16,
)
from_decoder_prompt=True)
for i in range(output_len):
seq.append_token_id(
......@@ -158,16 +165,11 @@ def create_seq_group_encoder_decoder(
)
seqs.append(seq)
# Encoder sequence
encoder_seq = Sequence(
seq_id=seq_id_start + len(seq_output_lens),
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
# Encoder input sequence
encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs,
block_size=16,
)
from_decoder_prompt=False)
return SequenceGroup(request_id=request_id,
seqs=seqs,
......@@ -178,3 +180,30 @@ def create_seq_group_encoder_decoder(
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
# Helper functions for scheduler tests
def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(out, token_id: int):
seq_groups = get_sequence_groups(out)
for seq_group in seq_groups:
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
"""For encoder/decoder models only:
Compare the outputs of HF and distributed vLLM when using greedy sampling.
Run:
```sh
cd $VLLM_PATH/tests
pytest distributed/test_basic_distributed_correctness_enc_dec.py
```
"""
import pytest
from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless
from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
("facebook/bart-large-cnn", "ray"),
("facebook/bart-large-cnn", "mp"),
])
@fork_new_process_for_each_test
def test_models(
model: str,
distributed_executor_backend: str,
hf_runner,
vllm_runner,
example_encoder_decoder_prompts,
) -> None:
'''
Test vLLM BART inference on more than one GPU, comparing
outputs against HF as a baseline.
Fork a new process for each test, to prevent CUDA from
being re-initialized by successive tests within the same
process.
Arguments:
* model: the HF ID of the specific BART variant under test
* distributed_executor_backend
* hf_runner: HuggingFace (HF) test model runner
* vllm_runner: vLLM test model runner
* example_encoder_decoder_prompts: test fixture which provides a
dictionary of dummy prompts
'''
dtype = "float"
max_tokens = 64
num_logprobs = 5
# Example inputs with non-trivial (i.e. not None/empty) encoder &
# decoder prompts.
test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_prompts, max_tokens, num_logprobs)
# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
......@@ -3,9 +3,9 @@ from unittest.mock import patch
import pytest
import torch
from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
override_backend_env_variable)
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@pytest.mark.parametrize(
......
......@@ -4,8 +4,6 @@ Tests:
* E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively
"encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching
"""
......@@ -15,19 +13,22 @@ import pytest
import torch
from tests.kernels.utils import *
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.utils import is_hip
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128]
......@@ -724,23 +725,58 @@ def _run_encoder_decoder_cross_attention_test(
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
batch_size: int, block_size: int, max_dec_seq_len: int,
max_enc_seq_len: int, monkeypatch):
def test_encoder_only(
num_heads: int,
head_size: int,
attn_backend: _Backend,
batch_size: int,
block_size: int,
max_dec_seq_len: int,
max_enc_seq_len: int,
):
'''
End-to-end encoder-only attention test:
* Construct fake test vectors for (1) encoder attention
* Construct (1) attention metadata structure with prefill-phase
encoder attention, and (2) an analogous attention metadata
structure but for decode-phase
* Test & validate encoder attention against ideal output
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
to be utilized.
Arguments:
* num_heads
* head_size,
* attn_backend: The attention backend to employ for testing
* batch_size
* block_size: KV cache block size
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
......@@ -774,7 +810,7 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
......@@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
def test_e2e_enc_dec_attn(
num_heads: int,
head_size: int,
backend_name: str,
attn_backend: _Backend,
batch_size: int,
block_size: int,
max_dec_seq_len: int,
max_enc_seq_len: int,
monkeypatch,
) -> None:
'''
End-to-end encoder/decoder test:
......@@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
cross-attention K/Vs are allowed to differ in seq len, as is often the case
for cross-attention.
This test utilizes PyTest monkey patching to force the attention backend
via an environment variable.
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
......@@ -830,23 +866,34 @@ def test_e2e_enc_dec_attn(
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior
of ModelRunner, which constructs a single attention metadata structure for
each prefill or decode run. A realistic scenario would rely on the
attention backend to utilize the appropriate attention metadata fields
according to the value of attn_metadata.attention_type. Thus, this test is
organized so as to confirm that the backend-under-test can handle a
shared prefill attention metadata structure & a shared decode attention
metadata structure.
of EncoderDecoderModelRunner, which constructs a single attention metadata
structure for each prefill or decode run. A realistic scenario would rely
on the attention backend to utilize the appropriate attention metadata
fields according to the value of attn_metadata.attention_type. Thus,
this test is organized so as to confirm that the backend-under-test can
handle a shared prefill attention metadata structure & a shared decode\
attention metadata structure.
Arguments:
* num_heads
* head_size,
* attn_backend: The attention backend to employ for testing
* batch_size
* block_size: KV cache block size
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name)
with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
......@@ -870,8 +917,9 @@ def test_e2e_enc_dec_attn(
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
# test params, including key/value tensors, cross-attention memory-mapping
# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
# cross-attention memory-mapping
(
prephase_cross_test_params,
......
......@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
......@@ -8,24 +8,10 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Union
import pytest
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import make_tensor_with_pad
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
class QKVInputs(NamedTuple):
......
"""Compare the outputs of HF and vLLM for BART models using greedy sampling.
Run `pytest tests/models/test_bart.py`.
"""
from vllm.utils import is_cpu
if not is_cpu():
# CPU backend is not currently supported with encoder/decoder models
# skip test definitions entirely to avoid importing GPU kernel libs
# (xFormers, etc.)
import pytest
from tests.models.utils import DecoderPromptType
from .utils import check_logprobs_close
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
DECODER_PROMPT_TYPES = ([
DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
DecoderPromptType.NONE
])
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
def test_models(
hf_runner,
vllm_runner,
example_encoder_decoder_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
decoder_prompt_type: DecoderPromptType,
) -> None:
'''
Test the vLLM BART model for a variety of encoder/decoder input prompts,
by validating it against HuggingFace (HF) BART.
Arguments:
* hf_runner: HuggingFace (HF) test model runner
* vllm_runner: vLLM test model runner
* example_encoder_decoder_prompts: test fixture which provides a
dictionary of dummy prompts
* model: the HF ID of the specific BART variant under test
* dtype: the tensor datatype to employ
* max_tokens
* num_logprobs
* decoder_prompt_type: key into the example_encoder_decoder_prompts
dictionary; selects specific encoder/decoder
prompt scenarios to test
A note on using HF BART as a baseline for validating vLLM BART,
specifically when the decoder prompt is None.
The HF GenerationMixin's default behavior is to force the first
decoded token to be <BOS> if the prompt does not already contain
<BOS> (this is accomplished using a logit
processor setting.)
So when we use HF BART as our baseline for comparison, note that
when the user provides a request with a None decoder prompt
(i.e. a singleton encoder prompt, or else an explicit encoder/
decoder prompt with the decoder sub-prompt set to None), HF and
vLLM handle this in different ways:
* HF will (1) tokenize the None prompt as an empty token-list,
(2) append <decoder-start-token> to the beginning, yielding
[<decoder-start-token>], (3) pass this token list to the model, and
then (4) after computing logits during prefill, override the model
logits & force <BOS> to be the first generated token.
* vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
start-token to the beginning, yielding [<decoder-start-token><BOS>],
(3) pass these tokens to the model & proceed with generation.
The net effect is that compared to vLLM, the list of HF *decoded* tokens
will contain one more initial <BOS> than the vLLM generated tokens,
because vLLM's <BOS> token is injected into the prompt rather than into
the generated output. This is in spite of the fact that overall, the
complete sequences (prompt + decoded tokens) produced by vLLM will match
HF.
So when we use HF decoded token output to validate vLLM's decoded token
output, the testing process must account for the difference in decoded
token sequences between vLLM and HF specifically in the
decoder-prompt-is-None case.
One option is to disable the logit processor feature that forces the
<BOS> token to be decoded (forced_bos_token_id = None), eliminating
the problem entirely. However this is not "normal" BART usage.
The other option is - only in the decoder-prompt-is-None case - to
discard the first decoded token from the HF output before comparing it
to vLLM.
To that end, when testing the scenario where the decoder prompt is None
(and only in that one scenario), this test skips the first HF decoded
token during the process of validating the vLLM decoded output.
'''
test_case_prompts = example_encoder_decoder_prompts[
decoder_prompt_type]
# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
# Note: currently encoder/decoder models are only compatible with
# enforce_eager=True. Normally this is not a problem because
# for encoder/decoder models vLLM will
# default to enforce_eager=True if enforce_eager
# is left unspecified. However, the
# VllmRunner test fixture (which wraps around the LLM class) defaults to
# enforce_eager=False (a behavior which a number of already-exisitng
# decoder-only unit tests expect), so when testing an encoder/decoder
# model we must explicitly specify enforce_eager=True in the VllmRunner
# constructor.
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0)
check_logprobs_close(outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens)
import warnings
from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs
......@@ -45,11 +46,27 @@ def check_logprobs_close(
outputs_1_lst: Sequence[TokensTextLogprobs],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True,
):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
Arguments:
* outputs_0_lst: First sequence to compare
* outputs_0_lst: Second sequence to compare
* name_0: sequence #0 name
* name_1: sequence #1 name
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard
before comparison, i.e. all
of sequence #1 will be compared to
sequence #0 beginning at index
num_outputs_0_skip_tokens
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
......@@ -65,6 +82,15 @@ def check_logprobs_close(
if logprobs_1 is None:
logprobs_1 = [None] * len(output_ids_1)
# Skip specified number of initial sequence #0 tokens
# & logprobs, leaving output text as-is for simplicity
# (text mismatches may generate warnings but do not
# cause the test to fail.)
if num_outputs_0_skip_tokens < 0:
raise ValueError("num_outputs_0_skip_tokens must be non-negative")
output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
......@@ -110,3 +136,13 @@ def check_logprobs_close(
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
class DecoderPromptType(Enum):
'''
For encoder/decoder models only -
'''
CUSTOM = 1
NONE = 2
EMPTY_STR = 3
from typing import List
import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
# CUDA graph scenarios to test
#
# Currently CUDA graph is not supported
ENFORCE_EAGER = [True]
BATCH_SIZES = [1, 4, 16, 64, 256]
def _create_model_runner(model: str, *args,
**kwargs) -> EncoderDecoderModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = EncoderDecoderModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True,
)
return model_runner
@pytest.mark.skipif(condition=is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_empty_seq_group(enforce_eager, ):
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(
input_tokens,
input_positions,
encoder_input_tokens,
encoder_input_positions,
attn_metadata,
return_seq_lens,
) = (
model_input.input_tokens,
model_input.input_positions,
model_input.encoder_input_tokens,
model_input.encoder_input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
assert input_tokens is None
assert input_positions is None
assert encoder_input_tokens is None
assert encoder_input_positions is None
assert attn_metadata is None
assert return_seq_lens is None
@pytest.mark.skipif(condition=is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_prepare_prompt(
batch_size,
enforce_eager,
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
)
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2]
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
encoder_input_tokens = model_input.encoder_input_tokens
encoder_input_positions = model_input.encoder_input_positions
cross_slot_mapping = attn_metadata.cross_slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
assert len(cross_slot_mapping) == len(encoder_input_tokens)
# Verify input metadata is correct for prompts.
# - Decoder attention metadata
device = model_runner.device
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.equal(attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0
# - Encoder attention metadata
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
assert torch.equal(
attn_metadata.encoder_seq_lens_tensor,
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
# Test decoder subquery start locs.
start_idx = 0
start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
start_loc.append(start_idx)
assert torch.equal(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device),
)
# Test decoder seq start locs & context lengths
assert torch.equal(
attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device),
)
assert torch.equal(
attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int,
device=device),
)
# Verify block tables are correct for prompts
# - Decoder self-attention
expected = torch.tensor(
[[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device,
)
assert torch.equal(
attn_metadata.block_tables,
expected,
)
# - Encoder/decoder cross-attention
assert torch.equal(
attn_metadata.cross_block_tables,
expected,
)
# Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False
# Verify the lengths of input tokens & positions
# - Decoder
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
input_tokens,
input_positions,
)
# - Encoder
assert len(encoder_input_tokens) == sum(encoder_seq_lens)
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
encoder_input_tokens,
encoder_input_positions,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the prefill phase
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
# Compute the index offset of the final token in each
# prompt (recall that the prompts are concatenated)
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
sampling_metadata = model_input.sampling_metadata
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(
expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype,
)
assert torch.equal(actual, expected)
@pytest.mark.skipif(condition=is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_prepare_decode(
batch_size,
enforce_eager,
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
)
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2]
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
encoder_input_tokens = model_input.encoder_input_tokens
encoder_input_positions = model_input.encoder_input_positions
cross_slot_mapping = attn_metadata.cross_slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
assert len(cross_slot_mapping) == len(encoder_input_tokens)
# Verify input metadata is correct for decode phase.
# - Decoder attention metadata
device = model_runner.device
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_decode_tokens > 0
assert torch.equal(attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(seq_lens)
# - Encoder attention metadata
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
assert torch.equal(
attn_metadata.encoder_seq_lens_tensor,
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
# Test decoder subquery start locs.
start_idx = 0
start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += 1
start_loc.append(start_idx)
assert torch.equal(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device),
)
# Test decoder seq start locs. Note that for normal prefill it is
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
# Test seq_start_loc and context lengths
assert torch.equal(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
)
assert torch.equal(
attn_metadata.context_lens_tensor,
torch.tensor([seq_len - 1 for seq_len in seq_lens],
dtype=torch.int,
device=device))
# Verify block tables are correct for prompts
# - Decoder self-attention
expected = torch.tensor(
[block_tables[0] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.equal(
attn_metadata.block_tables,
expected,
)
# - Encoder/decoder cross-attention
expected = torch.tensor(
[cross_block_table for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.equal(
attn_metadata.cross_block_tables,
expected,
)
# Cuda graph should is currently not supported for encoder/decoer.
assert attn_metadata.use_cuda_graph is False
# Verify the lengths of input tokens & positions
# - Decoder
assert len(input_tokens) == len(seq_lens)
assert len(input_positions) == len(seq_lens)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
input_tokens,
input_positions,
)
# - Encoder
assert len(encoder_input_tokens) == 0
assert len(encoder_input_tokens) == 0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
encoder_input_tokens,
encoder_input_positions,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the decode phase
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
# Compute the index offset of the final token in each
# sequence's decoded outputs; since a single token is
# decoded per iteration per sequence, then the length
# of the decoded tokens for a given sequence is 1 and
# the final index offset into a given sequence's
# generated tokens is 0 (i.e. the expected sampling index
# for a given sequence is just `selected_token_start_idx`)
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_input.sampling_metadata
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(
expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype,
)
assert torch.equal(actual, expected)
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder)
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
......@@ -8,6 +9,7 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"get_attn_backend",
......
......@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
......
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
from typing import Optional, Type
from typing import Generator, Optional, Type
import torch
......@@ -8,7 +10,8 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
is_tpu, is_xpu)
logger = init_logger(__name__)
......@@ -24,6 +27,66 @@ class _Backend(enum.Enum):
IPEX = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None
backend_members = _Backend.__members__
if backend_name not in backend_members:
raise ValueError(f"Invalid attention backend '{backend_name}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
return _Backend[backend_name]
def get_env_variable_attn_backend() -> Optional[_Backend]:
'''
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* None otherwise
'''
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return (None
if backend_name is None else backend_name_to_enum(backend_name))
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
'''
Force all attention operations to use a specified backend.
Passing `None` for the argument re-enables automatic
backend selection.,
Arguments:
* attn_backend: backend selection (None to revert to auto)
'''
global forced_attn_backend
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> Optional[_Backend]:
'''
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
'''
return forced_attn_backend
@lru_cache(maxsize=None)
def get_attn_backend(
num_heads: int,
......@@ -101,16 +164,20 @@ def which_attn_to_use(
# Default case.
selected_backend = _Backend.FLASH_ATTN
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
backend_members = _Backend.__members__
if backend_by_env_var not in backend_members:
raise ValueError(
f"Invalid attention backend '{backend_by_env_var}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
selected_backend = _Backend[backend_by_env_var]
selected_backend = backend_name_to_enum(backend_by_env_var)
if is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
......@@ -193,3 +260,35 @@ def which_attn_to_use(
selected_backend = _Backend.XFORMERS
return selected_backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: _Backend) -> Generator[None, None, None]:
'''
Globally force a vLLM attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Arguments:
* attn_backend: attention backend to force
Returns:
* Generator
'''
# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
......@@ -12,7 +12,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once)
......@@ -87,6 +88,9 @@ class ModelConfig:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False -
except for encoder/decoder models, which currently require
eager mode.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
......@@ -121,7 +125,7 @@ class ModelConfig:
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
......@@ -160,6 +164,34 @@ class ModelConfig:
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
if getattr(self.hf_config, 'is_encoder_decoder', False):
if self.enforce_eager is None:
# *Only for encoder/decoder models* and
# *only if enforce_eager is unset*, override
# to enforce_eager=True
#
# Add a logger message since it is *somewhat* non-intuitive that
# enforce_eager is True when the user has not specified its
# value.
logger.info("Forcing enforce_eager == True because "
"enforce_eager setting was unspecified and "
"CUDAGraph is not supported with encoder/ "
"decoder models.")
self.enforce_eager = True
if not self.enforce_eager:
# Eager mode explicitly disabled by user for an encoder/
# decoder model; however CUDAGRAPH + encoder/decoder is
# not currently supported
raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
elif self.enforce_eager is None:
# *Only for decoder-only models*, enforce_eager
# defaults to False if unset. This is intuitive
# so no logging message needed.
self.enforce_eager = False
if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
......
"""Block manager utils."""
from vllm.sequence import SequenceGroup
# Exception strings for non-implemented block manager enc/dec scenarios
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
def _get_block_mgr_sliding_window_attr(block_mgr):
......
......@@ -392,6 +392,19 @@ class Scheduler:
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
self._free_seq_group_cross_attn_blocks(aborted_group)
def _free_seq_group_cross_attn_blocks(
self,
seq_group: SequenceGroup,
) -> None:
"""
Free a sequence group from a cross-attention block table.
Has no effect on decoder-only models.
"""
if seq_group.is_encoder_decoder():
self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
......@@ -963,6 +976,17 @@ class Scheduler:
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}
if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
encoder_seq_data = seq_group.get_encoder_seq().data
# Block table for cross-attention
# Also managed at SequenceGroup level
cross_block_table = self.block_manager.get_cross_block_table(
seq_group)
else:
encoder_seq_data = None
cross_block_table = None
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
......@@ -1001,6 +1025,8 @@ class Scheduler:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
......@@ -1032,6 +1058,8 @@ class Scheduler:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment