Commit 20316346 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 31584b45 cc6a0017
...@@ -58,7 +58,8 @@ class TwoBatchOverlap(): ...@@ -58,7 +58,8 @@ class TwoBatchOverlap():
self.left_thread.start() self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start() self.right_thread.start()
logger.info('tbo:two batch overlap start') if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self): def finish_thread(self):
self.left_thread.join() self.left_thread.join()
......
...@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context ...@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1 from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata( ...@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial. # Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0 common_prefix_len = 0
metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
if runner.cascade_attn_enabled: if runner.cascade_attn_enabled:
common_prefix_len = runner._compute_cascade_attn_prefix_len( common_prefix_len = runner._compute_cascade_attn_prefix_len(
num_scheduled_tokens, num_scheduled_tokens,
scheduler_output. scheduler_output.
num_common_prefix_blocks[kv_cache_group_id], num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec, kv_cache_group_spec.kv_cache_spec,
runner.attn_metadata_builders[kv_cache_group_id], metadata_builder,
) )
if req_offset > 0: if req_offset > 0:
origin_block_table = runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table origin_block_table = metadata_builder.block_table.block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table[req_offset:, :] metadata_builder.block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping origin_slot_mapping = metadata_builder.block_table.slot_mapping
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = \ metadata_builder.block_table.slot_mapping = \
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:] origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:]
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
_num_decodes_record = metadata_builder._num_decodes
_num_prefills_record = metadata_builder._num_prefills
_num_decode_tokens_record = metadata_builder._num_decode_tokens
_num_prefill_tokens_record = metadata_builder._num_prefill_tokens
metadata_builder._num_decodes = 0
metadata_builder._num_prefills = num_reqs
metadata_builder._num_decode_tokens = 0
metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
attn_metadata_i = ( attn_metadata_i = (
runner.attn_metadata_builders[kv_cache_group_id].build( metadata_builder.build(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata
if req_offset > 0: if req_offset > 0:
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table metadata_builder.block_table.block_table = origin_block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = origin_slot_mapping metadata_builder.block_table.slot_mapping = origin_slot_mapping
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
metadata_builder._num_decodes = _num_decodes_record
metadata_builder._num_prefills = _num_prefills_record
metadata_builder._num_decode_tokens = _num_decode_tokens_record
metadata_builder._num_prefill_tokens = _num_prefill_tokens_record
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
...@@ -288,12 +306,16 @@ def tbo_split_and_execute_model( ...@@ -288,12 +306,16 @@ def tbo_split_and_execute_model(
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
use_tbo = False use_tbo = False
if len(scheduler_output.num_scheduled_tokens) > 1:
split_scheduler_output(runner, scheduler_output)
if input_split.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
input_split.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
if isinstance(runner.attn_metadata_builders[0], MLACommonMetadataBuilder) and \
runner.attn_metadata_builders[0]._num_decodes > 0: #is mla decode
use_tbo = False
else:
if len(scheduler_output.num_scheduled_tokens) > 1:
split_scheduler_output(runner, scheduler_output)
if input_split.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
input_split.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
if use_tbo: if use_tbo:
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
num_input_tokens_right = num_input_tokens - num_input_tokens_left num_input_tokens_right = num_input_tokens - num_input_tokens_left
...@@ -319,7 +341,8 @@ def tbo_split_and_execute_model( ...@@ -319,7 +341,8 @@ def tbo_split_and_execute_model(
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
runner.vllm_config, runner.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp): num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
runner.maybe_setup_kv_connector(scheduler_output) runner.maybe_setup_kv_connector(scheduler_output)
model_output = runner.model( model_output = runner.model(
......
...@@ -50,7 +50,8 @@ class TwoBatchOverlap(): ...@@ -50,7 +50,8 @@ class TwoBatchOverlap():
self.left_thread.start() self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start() self.right_thread.start()
logger.info('tbo:two batch overlap start') if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self): def finish_thread(self):
self.left_thread.join() self.left_thread.join()
...@@ -90,7 +91,8 @@ class TwoBatchOverlap(): ...@@ -90,7 +91,8 @@ class TwoBatchOverlap():
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.model_runner.vllm_config, self.model_runner.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp): num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
model_output = self.model_runner.model( model_output = self.model_runner.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
......
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
class V1ZeroEagleProposer(EagleProposer):
def __init__(self, vllm_config, device, runner=None):
super().__init__(vllm_config, device, runner)
self.spec_scheduler_max_num_tokens = 0
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
# [batch_size]
sampling_metadata: SamplingMetadata,
decoding: bool = False,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
max_query_len = self.spec_scheduler_max_num_tokens
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
slot_mapping=target_slot_mapping,
spec_layer_decoding=decoding
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata
)
else:
raise ValueError(f"Unsupported method: {self.method}")
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if (decoding and self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices]
if self.method == "deepseek_mtp":
hidden_states = last_hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.num_decodes = batch_size
attn_metadata.num_decode_tokens = batch_size
attn_metadata.num_prefills = 0
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table,
seq_lens=seq_lens,
)
for i in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.decode.seq_lens += 1
else:
attn_metadata.seq_lens += 1
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
ret_hidden_states = self.model(
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states[:batch_size]
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
...@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata ...@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
...@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = [] self.last_sampled_token_lens = []
self.last_sampler_event = torch.cuda.Event(enable_timing=False) self.last_sampler_event = torch.cuda.Event(enable_timing=False)
self.last_sampler_host_tokens = None self.last_sampler_host_tokens = None
self.token_ids_cpu_fix_recode = [] self.token_ids_cpu_fix_record = []
self.last_draft_token_ids = None self.last_draft_token_ids = None
self.last_draft_host_tokens = None self.last_draft_host_tokens = None
self.last_draft_event = torch.cuda.Event(enable_timing=False) self.last_draft_event = torch.cuda.Event(enable_timing=False)
self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
self.spec_scheduler_max_num_tokens = 0
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
self)
def _prepare_inputs( def _prepare_inputs(
self, self,
...@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner):
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32) num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens) max_num_scheduled_tokens = max(tokens)
self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens
# Get request indices. # Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner):
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
sampled_token_ids: list[list[int]], num_accepted_tokens_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor,
...@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner):
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop. # TODO(woosuk): Refactor the loop.
if self.last_sampled_token_ids is not None: row_indices = torch.arange(sampled_token_ids.size(0), device=sampled_token_ids.device)
next_token_ids = self.last_sampled_token_ids.flatten() next_token_ids = sampled_token_ids[row_indices, num_accepted_tokens_tensor].flatten()
else:
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
# At this moment, we assume all eagle layers belong to the same KV # At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata. # cache group, thus using the same attention metadata.
eagle_attn_metadata = attn_metadata[ eagle_attn_metadata = attn_metadata[
...@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner):
else: else:
block_table = None block_table = None
spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
if spec_decode_metadata is None: if spec_decode_metadata is None:
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens] target_token_ids = self.input_ids[:num_scheduled_tokens]
...@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens = eagle_attn_metadata.query_start_loc cu_num_tokens = eagle_attn_metadata.query_start_loc
else: else:
# TODO(woosuk): Refactor this. # TODO(woosuk): Refactor this.
num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids]
num_accepted_tokens_tensor = async_tensor_h2d(
num_accepted_tokens,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
cu_num_tokens, token_indices = self.drafter.prepare_inputs( cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc, eagle_attn_metadata.query_start_loc,
num_accepted_tokens_tensor, num_accepted_tokens_tensor,
) )
spec_scheduler_max_num_tokens = 1
target_token_ids = self.input_ids[token_indices] target_token_ids = self.input_ids[token_indices]
# TODO(woosuk): Support M-RoPE. # TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices] target_positions = self.positions[token_indices]
...@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner):
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
...@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens=cu_num_tokens, cu_num_tokens=cu_num_tokens,
block_table=block_table, block_table=block_table,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
decoding=spec_decode_metadata is not None decoding=spec_decode_metadata is not None,
) )
spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist() spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
self.last_draft_token_ids = draft_token_ids self.last_draft_token_ids = draft_token_ids
...@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output, scheduler_output,
) )
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
fix_req_ids = None fix_req_ids = None
fix_sampled_token_ids = None fix_sampled_token_ids = None
fix_draft_token_ids = None fix_draft_token_ids = None
fix_draft_req_ids = self.last_sampled_req_ids fix_draft_req_ids = self.last_sampled_req_ids
is_output_valid = False is_output_valid = False
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
if self.speculative_config: if self.speculative_config:
self.spec_sampler_event.synchronize()
if max_gen_len == 1: if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output( valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, sampled_token_ids_cpu,
self.input_batch.vocab_size, self.input_batch.vocab_size,
) )
self.last_sampler_host_tokens = None self.last_sampler_host_tokens = None
...@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner):
if self.last_sampler_host_tokens != None: if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize() self.last_sampler_event.synchronize()
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist() fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_recode: for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx] if start_idx == -1:
continue
req_id = fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(fix_req_ids): for req_idx, req_id in enumerate(fix_req_ids):
if req_id in self.requests: if req_id in self.requests:
req_state = self.requests[req_id] req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx] token_idx = self.last_sampled_token_lens[req_idx]
req_state.output_token_ids[token_idx] = fix_sampled_token_ids[req_idx][0] if token_idx == -1:
continue
fix_len = len(fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True) self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record() self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids self.last_sampled_token_ids = sampled_token_ids
...@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner):
# NOTE(woosuk): As an exception, when using PP, the scheduler sends # NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication # the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker. # between the first-stage worker and the last-stage worker.
self.token_ids_cpu_fix_recode.clear() self.token_ids_cpu_fix_record.clear()
self.last_sampled_req_ids = [] self.last_sampled_req_ids = []
self.last_sampled_token_lens = [] self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx]
self.last_sampled_req_ids.append(req_id)
cache_output_len = -1
if not sampled_ids: if not sampled_ids:
self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx] start_idx = self.input_batch.num_tokens_no_spec[req_idx]
...@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self.input_batch.token_ids_cpu[req_idx, self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids start_idx:end_idx] = sampled_ids
self.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx]) self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
if req_id in self.requests: if req_id in self.requests:
req_state = self.requests[req_id] req_state = self.requests[req_id]
self.last_sampled_req_ids.append(req_id) cache_output_len = len(req_state.output_token_ids)
self.last_sampled_token_lens.append(len(req_state.output_token_ids))
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
self.last_sampled_token_lens.append(cache_output_len)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment