Unverified Commit 15702038 authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[Spec Decode] (1/2) Remove batch expansion (#8839)

parent 22f5851b
...@@ -208,7 +208,7 @@ steps: ...@@ -208,7 +208,7 @@ steps:
- tests/spec_decode - tests/spec_decode
commands: commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- label: LoRA Test %N # 15min each - label: LoRA Test %N # 15min each
mirror_hardwares: [amd] mirror_hardwares: [amd]
......
...@@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
seq_lens=seq_lens if seq_lens else None, seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else [1] * batch_size,
device=device, device=device,
pin_memory=is_pin_memory_available()) pin_memory=is_pin_memory_available())
# the logits tensor is modified in-place by the sampler # the logits tensor is modified in-place by the sampler
......
...@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, ...@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
max_output_len=32, max_output_len=32,
seed=seed, seed=seed,
temperature=0.0) temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
...@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
temperature=0.0) temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__": if __name__ == "__main__":
import pytest import pytest
pytest.main([__file__]) pytest.main([__file__])
...@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len, max_output_len=output_len,
seed=seed, seed=seed,
temperature=0.0) temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
...@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len, max_output_len=output_len,
seed=seed, seed=seed,
temperature=0.0) temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
...@@ -173,7 +173,6 @@ def test_same_output_for_multi_step(): ...@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
......
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker
from .utils import create_batch, create_worker
def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
device: str) -> SpeculativeProposals:
proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
device=device)
proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
proposal_lens)
def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(score1.token_ids, score2.token_ids)
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scoroer(model_name: str, batch_size: int, propose_len: int,
device: str) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score
"""
seed = 0
block_size = 32
num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed)
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size
proposals = create_proposal(batch_size, propose_len, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=propose_len)
batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)
batch_expansion_score = batch_expansion_scorer.score_proposals(
requests, proposals)
mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
mqa_score = mqa_scorer.score_proposals(requests, proposals)
assert_score_equal(batch_expansion_score, mqa_score)
...@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, ...@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
@pytest.mark.parametrize("acceptance_sampler_method", @pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"]) ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode() @torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int, def test_batch_expansion_correctly_calls_target_model(
acceptance_sampler_method: str): k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct """Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out. inputs with batch expansion. Everything else is mocked out.
""" """
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False) target_worker = mock_worker(use_spec=False)
...@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int, ...@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
target_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False, disable_logprobs=False,
metrics_collector=metrics_collector) metrics_collector=metrics_collector,
disable_mqa_scorer=True)
worker.init_device() worker.init_device()
vocab_size = 32_000 vocab_size = 32_000
......
...@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts( ...@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
for i, final_len in enumerate(final_prompt_lens) for i, final_len in enumerate(final_prompt_lens)
} }
return [ seq_grou_metadata_list = []
SequenceGroupMetadata( for i, (prompt_token_ids,
request_id=str(i), cont_token_ids) in enumerate(zip(prompts, continuations)):
is_prompt=len(cont_token_ids) == 0, data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
seq_data={ data.update_num_computed_tokens(
i: SequenceData.from_seqs(prompt_token_ids[:], len(prompt_token_ids) + len(cont_token_ids) - 1)
cont_token_ids[:]), seq_data = {i: data}
}, seq_grou_metadata_list.append(
sampling_params=SamplingParams(temperature=0.0, ), SequenceGroupMetadata(
block_tables={i: block_allocations[i][:]}, request_id=str(i),
) for i, (prompt_token_ids, is_prompt=len(cont_token_ids) == 0,
cont_token_ids) in enumerate(zip(prompts, continuations)) seq_data=seq_data,
] sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_grou_metadata_list
def assert_logprobs_dict_allclose( def assert_logprobs_dict_allclose(
......
...@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional[ _cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None "BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[ _cached_decode_metadata: Optional[
......
...@@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------| # |-------------------- seq_len ---------------------|
# |-- query_len ---| # |-- query_len ---|
# Maximum query length in the batch. None for decoding. # Maximum query length in the batch.
max_query_len: Optional[int] max_query_len: Optional[int]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding # Maximum sequence length among prefill batch. 0 if there are decoding
# requests only. # requests only.
max_prefill_seq_len: int max_prefill_seq_len: int
...@@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
decode_query_len=0,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0, max_decode_seq_len=0,
...@@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None, decode_query_len=self.decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None, query_start_loc=None,
...@@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder( ...@@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder(
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
else: else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len) self.curr_seq_lens.append(curr_seq_len)
...@@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder( ...@@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder(
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens) max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
decode_query_len = max(decode_query_lens)
else:
decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
...@@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder( ...@@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder(
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len, max_query_len=max_query_len,
decode_query_len=decode_query_len,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
...@@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl):
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
_, num_head, head_dim = decode_query.shape
decode_query = decode_query.reshape(-1,
decode_meta.decode_query_len,
num_head, head_dim)
decode_output = torch.ops.vllm.flash_attn_with_kvcache( decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1), decode_query,
key_cache, key_cache,
value_cache, value_cache,
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
...@@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
).squeeze(1) )
if prefill_output is None: if prefill_output is None:
assert decode_output is not None assert decode_output is not None
...@@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl):
if decode_output is None: if decode_output is None:
assert prefill_output is not None assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size) return prefill_output.view(num_prefill_tokens, hidden_size)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0) output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
device = self.runner.device device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
...@@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int, dtype=torch.int,
device=device, device=device,
) )
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
......
...@@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed # (batch_size,) A tensor of context lengths (tokens that are computed
# so far). # so far).
context_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
......
...@@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState): ...@@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState):
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None, seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size], seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=None, max_query_len=1,
decode_query_len=1,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture, max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None, query_start_loc=None,
......
...@@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding. # Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None max_query_len: Optional[int] = None
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in # (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length # the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10]. # is [4, 6], it is [0, 4, 10].
......
...@@ -1116,6 +1116,7 @@ class SpeculativeConfig: ...@@ -1116,6 +1116,7 @@ class SpeculativeConfig:
speculative_model_quantization: Optional[str], speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int], speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int], num_speculative_tokens: Optional[int],
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
use_v2_block_manager: bool, use_v2_block_manager: bool,
...@@ -1150,6 +1151,9 @@ class SpeculativeConfig: ...@@ -1150,6 +1151,9 @@ class SpeculativeConfig:
num_speculative_tokens (Optional[int]): The number of speculative num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required. model config if present, otherwise is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip the speculative model. Used when testing the ability to skip
speculation for some sequences. speculation for some sequences.
...@@ -1304,6 +1308,7 @@ class SpeculativeConfig: ...@@ -1304,6 +1308,7 @@ class SpeculativeConfig:
draft_model_config, draft_model_config,
draft_parallel_config, draft_parallel_config,
num_speculative_tokens, num_speculative_tokens,
speculative_disable_mqa_scorer,
speculative_disable_by_batch_size, speculative_disable_by_batch_size,
ngram_prompt_lookup_max, ngram_prompt_lookup_max,
ngram_prompt_lookup_min, ngram_prompt_lookup_min,
...@@ -1400,6 +1405,7 @@ class SpeculativeConfig: ...@@ -1400,6 +1405,7 @@ class SpeculativeConfig:
draft_model_config: ModelConfig, draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig, draft_parallel_config: ParallelConfig,
num_speculative_tokens: int, num_speculative_tokens: int,
speculative_disable_mqa_scorer: Optional[bool],
speculative_disable_by_batch_size: Optional[int], speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int], ngram_prompt_lookup_min: Optional[int],
...@@ -1446,6 +1452,7 @@ class SpeculativeConfig: ...@@ -1446,6 +1452,7 @@ class SpeculativeConfig:
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens self.num_speculative_tokens = num_speculative_tokens
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
self.speculative_disable_by_batch_size = \ self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
......
...@@ -162,6 +162,7 @@ class EngineArgs: ...@@ -162,6 +162,7 @@ class EngineArgs:
speculative_model_quantization: Optional[str] = None speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
speculative_disable_mqa_scorer: Optional[bool] = False
speculative_max_model_len: Optional[int] = None speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None
...@@ -640,6 +641,12 @@ class EngineArgs: ...@@ -640,6 +641,12 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens, default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from ' help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.') 'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-mqa-scorer',
action='store_true',
help=
'If set to True, the MQA scorer will be disabled in speculative '
' and fall back to batch expansion')
parser.add_argument( parser.add_argument(
'--speculative-draft-tensor-parallel-size', '--speculative-draft-tensor-parallel-size',
'-spec-draft-tp', '-spec-draft-tp',
...@@ -970,6 +977,7 @@ class EngineArgs: ...@@ -970,6 +977,7 @@ class EngineArgs:
speculative_draft_tensor_parallel_size = \ speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size, self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens, num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self. speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size, speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len, speculative_max_model_len=self.speculative_max_model_len,
......
...@@ -1110,6 +1110,8 @@ class LLMEngine: ...@@ -1110,6 +1110,8 @@ class LLMEngine:
update_prefill_num_computed_tokens(seq_group, seq_group_meta, update_prefill_num_computed_tokens(seq_group, seq_group_meta,
len(output), len(output),
is_first_step_output) is_first_step_output)
elif not is_async:
seq_group.update_num_computed_tokens(1)
if outputs: if outputs:
for o in outputs: for o in outputs:
...@@ -1133,8 +1135,16 @@ class LLMEngine: ...@@ -1133,8 +1135,16 @@ class LLMEngine:
else: else:
self.output_processor.process_prompt_logprob(seq_group, output) self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample: if seq_group_meta.do_sample:
self.output_processor.process_outputs( output_token_num = self.output_processor.process_outputs(
seq_group, output, is_async) seq_group, output, is_async)
if self.speculative_config:
# We -1 here because we always
# (w/o speculative decoding) add the number of
# computed tokens by one in the decoding phase.
# Therefore, we remove that one token that
# is already added.
seq_group.update_num_computed_tokens(output_token_num -
1)
if seq_group.is_finished(): if seq_group.is_finished():
finished_now.append(i) finished_now.append(i)
...@@ -1251,11 +1261,12 @@ class LLMEngine: ...@@ -1251,11 +1261,12 @@ class LLMEngine:
# decodes after the very first step. Therefore, # decodes after the very first step. Therefore,
# we skip the update to the num_computed_tokens # we skip the update to the num_computed_tokens
# here. # here.
pass seq_group.update_num_computed_tokens(1)
else: else:
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size) seq_group_metadata.token_chunk_size)
else:
seq_group.update_num_computed_tokens(1)
if seq_group_metadata.do_sample: if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, ( assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample" "Async output processor expects a single sample"
...@@ -1266,7 +1277,6 @@ class LLMEngine: ...@@ -1266,7 +1277,6 @@ class LLMEngine:
assert len(seq_group.seqs) == 1 assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0] seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs)
seq_group.update_num_computed_tokens(1)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import Callable, List, Optional
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC):
@abstractmethod @abstractmethod
def process_outputs(self, sequence_group: SequenceGroup, def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput], outputs: List[SequenceGroupOutput],
is_async: bool) -> None: is_async: bool) -> Optional[int]:
"""Process new token ids for the sequence group. Handles logic such as """Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the detokenization, stop checking, and freeing/forking sequences in the
scheduler. scheduler.
Return the number of new tokens generated in the sequence group.
The returned value is optional because it is only used for
speculative decoding mqa scorer.
""" """
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment