Unverified Commit 5faedf1b authored by Kevin Lin's avatar Kevin Lin Committed by GitHub
Browse files

[Spec Decode] Move ops.advance_step to flash attn advance_step (#8224)

parent 02751a7a
...@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, ...@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
...@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata):
) )
return self._cached_decode_metadata return self._cached_decode_metadata
def advance_step(self, num_seqs: int, num_queries: int): def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
""" """
Update metadata in-place to advance one decode step. Update metadata in-place to advance one decode step.
""" """
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured # When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in # batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries # the batch. For --enforce-eager mode, num_seqs == num_queries
...@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
......
...@@ -2,7 +2,6 @@ from typing import List, Optional ...@@ -2,7 +2,6 @@ from typing import List, Optional
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
try: try:
...@@ -116,18 +115,9 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -116,18 +115,9 @@ class TP1DraftModelRunner(ModelRunner):
# Update attn_metadata # Update attn_metadata
attn_metadata = model_input.attn_metadata attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata) assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)
# Update GPU tensors attn_metadata.advance_step(model_input, sampled_token_ids,
ops.advance_step(num_seqs=num_seqs, self.block_size, num_seqs, num_queries)
num_queries=num_queries,
block_size=self.block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
# Update sampling_metadata # Update sampling_metadata
sampling_metadata = model_input.sampling_metadata sampling_metadata = model_input.sampling_metadata
......
...@@ -13,7 +13,6 @@ except ModuleNotFoundError: ...@@ -13,7 +13,6 @@ except ModuleNotFoundError:
import torch import torch
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
...@@ -499,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -499,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
attn_metadata = frozen_model_input.attn_metadata attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata) assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)
attn_metadata.advance_step(
# Update GPU tensors frozen_model_input,
ops.advance_step( model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
num_seqs=num_seqs, num_seqs, num_queries)
num_queries=num_queries,
block_size=self.block_size,
input_tokens=frozen_model_input.input_tokens,
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
input_positions=frozen_model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
if frozen_model_input.seq_lens is not None: if frozen_model_input.seq_lens is not None:
for i in range(num_queries): for i in range(num_queries):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment