Commit 20316346 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 31584b45 cc6a0017
......@@ -58,6 +58,7 @@ class TwoBatchOverlap():
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self):
......
......@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
......@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
if runner.cascade_attn_enabled:
common_prefix_len = runner._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
runner.attn_metadata_builders[kv_cache_group_id],
metadata_builder,
)
if req_offset > 0:
origin_block_table = runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = \
origin_block_table = metadata_builder.block_table.block_table
metadata_builder.block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = metadata_builder.block_table.slot_mapping
metadata_builder.block_table.slot_mapping = \
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:]
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
_num_decodes_record = metadata_builder._num_decodes
_num_prefills_record = metadata_builder._num_prefills
_num_decode_tokens_record = metadata_builder._num_decode_tokens
_num_prefill_tokens_record = metadata_builder._num_prefill_tokens
metadata_builder._num_decodes = 0
metadata_builder._num_prefills = num_reqs
metadata_builder._num_decode_tokens = 0
metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
attn_metadata_i = (
runner.attn_metadata_builders[kv_cache_group_id].build(
metadata_builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata
if req_offset > 0:
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = origin_slot_mapping
metadata_builder.block_table.block_table = origin_block_table
metadata_builder.block_table.slot_mapping = origin_slot_mapping
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
metadata_builder._num_decodes = _num_decodes_record
metadata_builder._num_prefills = _num_prefills_record
metadata_builder._num_decode_tokens = _num_decode_tokens_record
metadata_builder._num_prefill_tokens = _num_prefill_tokens_record
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
......@@ -288,12 +306,16 @@ def tbo_split_and_execute_model(
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
use_tbo = False
if isinstance(runner.attn_metadata_builders[0], MLACommonMetadataBuilder) and \
runner.attn_metadata_builders[0]._num_decodes > 0: #is mla decode
use_tbo = False
else:
if len(scheduler_output.num_scheduled_tokens) > 1:
split_scheduler_output(runner, scheduler_output)
if input_split.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
input_split.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
if use_tbo:
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
num_input_tokens_right = num_input_tokens - num_input_tokens_left
......@@ -319,7 +341,8 @@ def tbo_split_and_execute_model(
with set_forward_context(attn_metadata,
runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
runner.maybe_setup_kv_connector(scheduler_output)
model_output = runner.model(
......
......@@ -50,6 +50,7 @@ class TwoBatchOverlap():
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self):
......@@ -90,7 +91,8 @@ class TwoBatchOverlap():
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp):
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
......
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
class V1ZeroEagleProposer(EagleProposer):
def __init__(self, vllm_config, device, runner=None):
super().__init__(vllm_config, device, runner)
self.spec_scheduler_max_num_tokens = 0
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
# [batch_size]
sampling_metadata: SamplingMetadata,
decoding: bool = False,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
max_query_len = self.spec_scheduler_max_num_tokens
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
slot_mapping=target_slot_mapping,
spec_layer_decoding=decoding
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata
)
else:
raise ValueError(f"Unsupported method: {self.method}")
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if (decoding and self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices]
if self.method == "deepseek_mtp":
hidden_states = last_hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.num_decodes = batch_size
attn_metadata.num_decode_tokens = batch_size
attn_metadata.num_prefills = 0
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table,
seq_lens=seq_lens,
)
for i in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.decode.seq_lens += 1
else:
attn_metadata.seq_lens += 1
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
ret_hidden_states = self.model(
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states[:batch_size]
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
......@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
......@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = []
self.last_sampler_event = torch.cuda.Event(enable_timing=False)
self.last_sampler_host_tokens = None
self.token_ids_cpu_fix_recode = []
self.token_ids_cpu_fix_record = []
self.last_draft_token_ids = None
self.last_draft_host_tokens = None
self.last_draft_event = torch.cuda.Event(enable_timing=False)
self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
self.spec_scheduler_max_num_tokens = 0
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
self)
def _prepare_inputs(
self,
......@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner):
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
......@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner):
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: list[list[int]],
num_accepted_tokens_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
......@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner):
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
if self.last_sampled_token_ids is not None:
next_token_ids = self.last_sampled_token_ids.flatten()
else:
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
row_indices = torch.arange(sampled_token_ids.size(0), device=sampled_token_ids.device)
next_token_ids = sampled_token_ids[row_indices, num_accepted_tokens_tensor].flatten()
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata = attn_metadata[
......@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner):
else:
block_table = None
spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
......@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids]
num_accepted_tokens_tensor = async_tensor_h2d(
num_accepted_tokens,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_accepted_tokens_tensor,
)
spec_scheduler_max_num_tokens = 1
target_token_ids = self.input_ids[token_indices]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices]
......@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner):
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
......@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens=cu_num_tokens,
block_table=block_table,
sampling_metadata=sampling_metadata,
decoding=spec_decode_metadata is not None
decoding=spec_decode_metadata is not None,
)
spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
self.last_draft_token_ids = draft_token_ids
......@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output,
)
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
fix_req_ids = None
fix_sampled_token_ids = None
fix_draft_token_ids = None
fix_draft_req_ids = self.last_sampled_req_ids
is_output_valid = False
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
if self.speculative_config:
self.spec_sampler_event.synchronize()
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids.tolist()
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
sampled_token_ids_cpu,
self.input_batch.vocab_size,
)
self.last_sampler_host_tokens = None
......@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner):
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize()
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_recode:
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
req_id = fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
req_state.output_token_ids[token_idx] = fix_sampled_token_ids[req_idx][0]
if token_idx == -1:
continue
fix_len = len(fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
......@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner):
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
self.token_ids_cpu_fix_recode.clear()
self.token_ids_cpu_fix_record.clear()
self.last_sampled_req_ids = []
self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx]
self.last_sampled_req_ids.append(req_id)
cache_output_len = -1
if not sampled_ids:
self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
......@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx])
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
if req_id in self.requests:
req_state = self.requests[req_id]
self.last_sampled_req_ids.append(req_id)
self.last_sampled_token_lens.append(len(req_state.output_token_ids))
cache_output_len = len(req_state.output_token_ids)
req_state.output_token_ids.extend(sampled_ids)
self.last_sampled_token_lens.append(cache_output_len)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment