Commit 22a95571 authored by zhuwenwen's avatar zhuwenwen
Browse files

add v1 engine + deepseek r1 mtp + zero-overhead scheduler

parent ac4cc84e
...@@ -764,10 +764,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -764,10 +764,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True).contiguous() block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave( if envs.VLLM_ZERO_OVERHEAD:
block_table_tensor[:num_decodes, ...], decode_block_table_tensor = torch.empty((self._num_decode_tokens, block_table_tensor.shape[1]),
repeats, dim=0).contiguous() device=block_table_tensor.device)
decode_seq_lens = torch.repeat_interleave(seq_lens[:num_decodes], repeats, dim=0).contiguous() arange_np = np.arange(self._num_decodes)
indices_np = np.repeat(arange_np, query_lens)
indices = torch.from_numpy(indices_np).pin_memory().to(
block_table_tensor.device, non_blocking=True)
decode_block_table_tensor = block_table_tensor[indices].contiguous()
decode_seq_lens = seq_lens[indices].contiguous()
else:
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True).contiguous() seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
......
...@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.zero_overhead.v1.gpu_model_runner import execute_model_sampled, zero_prepare_inputs
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (bind_kv_cache, gather_mm_placeholders, from .utils import (bind_kv_cache, gather_mm_placeholders,
...@@ -1020,15 +1019,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1020,15 +1019,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9] # [0, 1, 2, 5, 6, 9]
target_logits_indices += arange target_logits_indices += arange
# TODO: Optimize the CPU -> GPU copy. if envs.VLLM_ZERO_OVERHEAD:
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device, logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
non_blocking=True) non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to( target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
else:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
# Compute the draft token ids. # Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208] # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
...@@ -1440,9 +1450,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1440,9 +1450,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ZERO_OVERHEAD:
zero_prepare_inputs(self, scheduler_output, input_ids)
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph: if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
...@@ -1591,21 +1598,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1591,21 +1598,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
if envs.VLLM_ZERO_OVERHEAD:
return execute_model_sampled(self, max_gen_len, sampled_token_ids,
discard_sampled_tokens_req_indices, scheduler_output,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
logprobs_lists,
prompt_logprobs_dict,
finished_sending,
finished_recving,
num_nans_in_logits)
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
......
...@@ -33,6 +33,7 @@ from vllm.v1.utils import report_usage_stats ...@@ -33,6 +33,7 @@ from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -187,8 +188,12 @@ class Worker(WorkerBase): ...@@ -187,8 +188,12 @@ class Worker(WorkerBase):
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
# Construct the model runner # Construct the model runner
self.model_runner: GPUModelRunner = GPUModelRunner( if envs.VLLM_ZERO_OVERHEAD:
self.vllm_config, self.device) self.model_runner: GPUModelRunner = V1ZeroModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
if self.rank == 0: if self.rank == 0:
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
......
...@@ -12,11 +12,15 @@ requsets_valid_token_len = {} ...@@ -12,11 +12,15 @@ requsets_valid_token_len = {}
def check_stop(request: Request, def check_stop(request: Request,
max_model_len: int, max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool: pooler_output: Optional[torch.Tensor] = None,
if request.request_id not in requsets_valid_token_len: use_valid_token_len:bool = False) -> bool:
requsets_valid_token_len[request.request_id] = 0 if use_valid_token_len:
return False if request.request_id not in requsets_valid_token_len:
valid_output_len = requsets_valid_token_len[request.request_id] requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len valid_num_tokens = request.num_prompt_tokens + valid_output_len
if (valid_num_tokens >= max_model_len if (valid_num_tokens >= max_model_len
or valid_output_len >= request.max_tokens): or valid_output_len >= request.max_tokens):
...@@ -60,110 +64,118 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -60,110 +64,118 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
# fix last model out in zero overhead # fix last model out in zero overhead
for req_idx, req_id in enumerate(model_runner_output.fix_req_ids): if model_runner_output.fix_req_ids is not None:
if req_id not in scheduler.requests: for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
continue if req_id not in scheduler.requests:
request = scheduler.requests[req_id] continue
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx] request = scheduler.requests[req_id]
if req_id not in requsets_valid_token_len: generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
requsets_valid_token_len[req_id] = 0 if req_id not in requsets_valid_token_len:
valid_output_len = requsets_valid_token_len[req_id] requsets_valid_token_len[req_id] = 0
fix_offset = valid_output_len - request.num_output_tokens valid_output_len = requsets_valid_token_len[req_id]
if isinstance(generated_token_ids, int): fix_offset = valid_output_len - request.num_output_tokens
request._output_token_ids[fix_offset] = generated_token_ids if isinstance(generated_token_ids, int):
request._all_token_ids[fix_offset] = generated_token_ids request._output_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1 request._all_token_ids[fix_offset] = generated_token_ids
else: requsets_valid_token_len[req_id] += 1
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else: else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids if valid_output_end == 0:
requsets_valid_token_len[req_id] += len(generated_token_ids) request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
stopped = False # Check for stop and update request state.
new_logprobs = None # This must be called before we make the EngineCoreOutput.
new_token_ids = generated_token_ids for num_new, output_token_id in enumerate(new_token_ids, 1):
kv_transfer_params = None stopped = check_stop(request, scheduler.max_model_len, True)
if stopped:
# Check for stop and update request state. kv_transfer_params = scheduler._free_request(request)
# This must be called before we make the EngineCoreOutput. del new_token_ids[num_new:] # Trim new tokens if needed.
for num_new, output_token_id in enumerate(new_token_ids, 1): break
stopped = check_stop(request, scheduler.max_model_len)
if stopped: pooler_output = None
kv_transfer_params = scheduler._free_request(request) if pooler_outputs:
del new_token_ids[num_new:] # Trim new tokens if needed. pooler_output = pooler_outputs[req_index]
break stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True)
pooler_output = None if stopped:
if pooler_outputs: kv_transfer_params = scheduler._free_request(request)
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len, # Extract sample logprobs if needed.
pooler_output) if request.sampling_params is not None \
if stopped: and request.sampling_params.logprobs is not None and logprobs:
kv_transfer_params = scheduler._free_request(request) # NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
# Extract sample logprobs if needed. new_logprobs = logprobs.slice(req_index, req_index + 1)
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs: if new_token_ids and scheduler.structured_output_manager.should_advance(
# NOTE: once we support N tokens per step (spec decode), request):
# the outer lists can be of length > 1. # NOTE: structured_output_request
new_logprobs = logprobs.slice(req_index, req_index + 1) # should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
if new_token_ids and scheduler.structured_output_manager.should_advance( request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request): req_id, new_token_ids)
# NOTE: structured_output_request
# should not be None if use_structured_output, we have # spec_token_ids comes from the model runner output
# check above, so safe to ignore type warning if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] request.num_nans_in_logits = num_nans_in_logits[req_id]
req_id, new_token_ids)
# Get prompt logprobs for this request.
# spec_token_ids comes from the model runner output prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if new_token_ids or pooler_output is not None \
request.num_nans_in_logits = num_nans_in_logits[req_id] or kv_transfer_params:
# Add newly generated spec token ids to the request. # Add EngineCoreOutput for this Request.
if spec_token_ids is not None: outputs[request.client_index].append(
if scheduler.structured_output_manager.should_advance(request): EngineCoreOutput(
metadata = request.structured_output_request request_id=req_id,
# Needs to happen after new_token_ids are accepted. new_token_ids=new_token_ids,
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] finish_reason=request.get_finished_reason(),
spec_token_ids[req_index]) new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else: else:
request.spec_token_ids = spec_token_ids[req_index] assert not prompt_logprobs_tensors
# Get prompt logprobs for this request. # fix last model out in zero overhead
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if model_runner_output.fix_draft_req_ids is not None:
if new_token_ids or pooler_output is not None \ for req_idx, req_id in enumerate(model_runner_output.fix_draft_req_ids):
or kv_transfer_params: if req_id not in scheduler.requests:
continue
# Add EngineCoreOutput for this Request. request = scheduler.requests[req_id]
outputs[request.client_index].append(
EngineCoreOutput( # Add newly generated spec token ids to the request.
request_id=req_id, if model_runner_output.fix_draft_tokens_ids is not None:
new_token_ids=new_token_ids, if scheduler.structured_output_manager.should_advance(request):
finish_reason=request.get_finished_reason(), metadata = request.structured_output_request
new_logprobs=new_logprobs, # Needs to happen after new_token_ids are accepted.
new_prompt_logprobs_tensors=prompt_logprobs_tensors, request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
pooling_output=pooler_output, model_runner_output.fix_draft_tokens_ids[req_idx])
stop_reason=request.stop_reason, else:
events=request.take_events(), request.spec_token_ids = model_runner_output.fix_draft_tokens_ids[req_idx]
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid # loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop. # expensive operations inside the loop.
for request in scheduler.running: for request in scheduler.running:
if request.is_finished():
continue
req_id = request.request_id req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0: if num_tokens_scheduled == 0:
...@@ -197,7 +209,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -197,7 +209,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if request.has_encoder_inputs: if request.has_encoder_inputs:
scheduler._free_encoder_inputs(request) scheduler._free_encoder_inputs(request)
stopped = False
new_logprobs = None new_logprobs = None
new_token_ids = generated_token_ids new_token_ids = generated_token_ids
kv_transfer_params = None kv_transfer_params = None
...@@ -210,19 +221,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -210,19 +221,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, scheduler.max_model_len) if model_runner_output.is_output_valid:
# if stopped: stopped = check_stop(request, scheduler.max_model_len,
# kv_transfer_params = scheduler._free_request(request) False)
# del new_token_ids[num_new:] # Trim new tokens if needed. if stopped:
# break kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None pooler_output = None
if pooler_outputs: if pooler_outputs:
pooler_output = pooler_outputs[req_index] if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len, pooler_output = pooler_outputs[req_index]
pooler_output) stopped = check_stop(request, scheduler.max_model_len,
# if stopped: pooler_output,
# kv_transfer_params = scheduler._free_request(request) False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params is not None \ if request.sampling_params is not None \
...@@ -252,6 +268,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -252,6 +268,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_token_ids[req_index]) spec_token_ids[req_index])
else: else:
request.spec_token_ids = spec_token_ids[req_index] request.spec_token_ids = spec_token_ids[req_index]
if model_runner_output.is_output_valid:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
if not stopped: if not stopped:
new_running.append(request) new_running.append(request)
......
from typing import Any, Optional, Union
import torch import torch
import numpy as np import numpy as np
from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.utils import async_tensor_h2d from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.utils import async_tensor_h2d, round_up
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
class V1ZeroModelRunner(): class V1ZeroModelRunner(GPUModelRunner):
def __init__(self): def __init__(self, vllm_config, device):
super().__init__(vllm_config, device)
self.last_sampled_token_ids = None self.last_sampled_token_ids = None
self.last_sampled_req_ids = [] self.last_sampled_req_ids = []
self.last_sampled_token_lens = [] self.last_sampled_token_lens = []
self.last_sampler_event = torch.cuda.Event(enable_timing=False) self.last_sampler_event = torch.cuda.Event(enable_timing=False)
self.last_sampler_host_tokens = None self.last_sampler_host_tokens = None
self.token_ids_cpu_fix_recode = [] self.token_ids_cpu_fix_recode = []
self.last_draft_token_ids = None
self.last_draft_host_tokens = None
self.last_draft_event = torch.cuda.Event(enable_timing=False)
def set_last_sampled_token_ids(self, sampled_token_ids): def zero_prepare_inputs(self, scheduler_output, input_ids):
self.last_sampled_token_ids = sampled_token_ids req_ids = self.input_batch.req_ids
self.last_sampled_req_ids = [] update_req_indices = []
self.last_sampled_token_lens = [] input_ids_indices = []
token_idx = 0
v1_zero_overhead = V1ZeroModelRunner() if self.last_draft_token_ids is not None:
draft_tokens_num = self.last_draft_token_ids.shape[1]
def zero_prepare_inputs(runner, scheduler_output, input_ids): for req_id in req_ids:
req_ids = runner.input_batch.req_ids if req_id in self.last_sampled_req_ids:
update_req_indices = [] req_idx = self.last_sampled_req_ids.index(req_id) * draft_tokens_num
input_ids_indices = [] for num_idx in range(draft_tokens_num):
token_idx = 0 update_req_indices.append(req_idx + num_idx)
if v1_zero_overhead.last_sampled_token_ids is None: input_ids_indices.append(token_idx + num_idx + 1)
return token_idx += draft_tokens_num + 1
sampled_tokens_num = v1_zero_overhead.last_sampled_token_ids.shape[1] if len(update_req_indices) > 0:
for req_id in req_ids: update_req_indices_tensor = async_tensor_h2d(update_req_indices, torch.int32,
if req_id in v1_zero_overhead.last_sampled_req_ids: self.device,
req_idx = v1_zero_overhead.last_sampled_req_ids.index(req_id) * sampled_tokens_num True)
update_req_indices.append(req_idx) input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
input_ids_indices.append(token_idx) self.device,
token_idx += scheduler_output.num_scheduled_tokens[req_id] True)
if len(update_req_indices) > 0: last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
update_req_indices_tensor = async_tensor_h2d(update_req_indices, torch.int32, input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
runner.device,
True)
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
runner.device,
True)
last_sampled_token_ids = v1_zero_overhead.last_sampled_token_ids.flatten()
for i in range(sampled_tokens_num):
input_ids[input_ids_indices_tensor + i] = last_sampled_token_ids[update_req_indices_tensor + i]
def execute_model_sampled(runner, max_gen_len, sampled_token_ids,
discard_sampled_tokens_req_indices, scheduler_output,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
logprobs_lists,
prompt_logprobs_dict,
finished_sending,
finished_recving,
num_nans_in_logits
):
fix_req_ids = None
fix_sampled_token_ids = None
if max_gen_len == 1:
# No spec decode tokens.
if v1_zero_overhead.last_sampler_host_tokens != None:
v1_zero_overhead.last_sampler_event.synchronize()
fix_sampled_token_ids = v1_zero_overhead.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in v1_zero_overhead.token_ids_cpu_fix_recode:
runner.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
fix_req_ids = v1_zero_overhead.last_sampled_req_ids
for req_idx, req_id in enumerate(fix_req_ids):
if req_id in runner.requests:
req_state = runner.requests[req_id]
token_idx = v1_zero_overhead.last_sampled_token_lens[req_idx]
req_state.output_token_ids[token_idx] = fix_sampled_token_ids[req_idx][0]
v1_zero_overhead.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
v1_zero_overhead.last_sampler_event.record()
v1_zero_overhead.set_last_sampled_token_ids(sampled_token_ids)
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = runner.rejection_sampler.parse_output(
sampled_token_ids,
runner.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
v1_zero_overhead.token_ids_cpu_fix_recode.clear()
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
continue
start_idx = runner.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= runner.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{runner.max_model_len}")
runner.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
v1_zero_overhead.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx])
runner.input_batch.num_tokens_no_spec[req_idx] = end_idx
runner.input_batch.num_tokens[req_idx] = end_idx
req_id = runner.input_batch.req_ids[req_idx]
if req_id in runner.requests:
req_state = runner.requests[req_id]
v1_zero_overhead.last_sampled_req_ids.append(req_id)
v1_zero_overhead.last_sampled_token_lens.append(len(req_state.output_token_ids))
req_state.output_token_ids.extend(sampled_ids)
if not runner.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
else: else:
spec_token_ids = runner.propose_draft_token_ids( update_req_indices = []
input_ids_indices = []
token_idx = 0
if self.last_sampled_token_ids is not None:
sampled_tokens_num = self.last_sampled_token_ids.shape[1]
for req_id in req_ids:
if req_id in self.last_sampled_req_ids:
req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num
update_req_indices.append(req_idx)
input_ids_indices.append(token_idx)
token_idx += scheduler_output.num_scheduled_tokens[req_id]
if len(update_req_indices) > 0:
update_req_indices_tensor = async_tensor_h2d(update_req_indices, torch.int32,
self.device,
True)
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device,
True)
last_sampled_token_ids = self.last_sampled_token_ids.flatten()
for i in range(sampled_tokens_num):
input_ids[input_ids_indices_tensor + i] = last_sampled_token_ids[update_req_indices_tensor + i]
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
attn_metadata: dict[str, Any],
) -> list[list[int]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
if sample_hidden_states.shape[0] == len(sampled_token_ids):
# The input to the target model does not include draft tokens.
hidden_states = sample_hidden_states
else:
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
spec_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
if self.last_sampled_token_ids is not None:
next_token_ids = self.last_sampled_token_ids.flatten()
else:
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_names[0]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None
num_rejected_tokens = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_tensor = async_tensor_h2d(
num_rejected_tokens,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens_tensor,
num_tokens,
)
target_token_ids = self.input_ids[token_indices]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=block_table,
sampling_metadata=sampling_metadata,
num_rejected_tokens=num_rejected_tokens
)
spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
self.last_draft_token_ids = draft_token_ids
self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
self.last_draft_event.record()
return spec_token_ids
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
if self.is_multimodal_model and get_pp_group().is_first_rank:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
positions = self.positions[:num_input_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
self.zero_prepare_inputs(scheduler_output, input_ids)
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors)
else:
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
):
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output, scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
) )
# Clear KVConnector state after all KVs are generated. # Get the valid generated tokens.
if has_kv_transfer_group(): sampled_token_ids = sampler_output.sampled_token_ids
get_kv_transfer_group().clear_connector_metadata() max_gen_len = sampled_token_ids.shape[-1]
runner.eplb_step() fix_req_ids = None
fix_sampled_token_ids = None
model_output = ZeroV1ModelRunnerOutput( fix_draft_token_ids = None
req_ids=runner.input_batch.req_ids, fix_draft_req_ids = self.last_sampled_req_ids
req_id_to_index=runner.input_batch.req_id_to_index, is_output_valid = False
sampled_token_ids=valid_sampled_token_ids, if self.speculative_config:
spec_token_ids=spec_token_ids, if max_gen_len == 1:
logprobs=logprobs_lists, valid_sampled_token_ids = sampled_token_ids.tolist()
prompt_logprobs_dict=prompt_logprobs_dict, else:
pooler_output=[], # Includes spec decode tokens.
finished_sending=finished_sending, valid_sampled_token_ids = self.rejection_sampler.parse_output(
finished_recving=finished_recving, sampled_token_ids,
num_nans_in_logits=num_nans_in_logits, self.input_batch.vocab_size,
fix_req_ids = fix_req_ids, )
fix_sampled_token_ids = fix_sampled_token_ids self.last_sampler_host_tokens = None
) self.last_sampled_token_ids = None
return model_output is_output_valid = True
\ No newline at end of file else:
# No spec decode tokens.
fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize()
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_recode:
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
req_state.output_token_ids[token_idx] = fix_sampled_token_ids[req_idx][0]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
self.token_ids_cpu_fix_recode.clear()
self.last_sampled_req_ids = []
self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
if req_id in self.requests:
req_state = self.requests[req_id]
self.last_sampled_req_ids.append(req_id)
self.last_sampled_token_lens.append(len(req_state.output_token_ids))
req_state.output_token_ids.extend(sampled_ids)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step()
model_output = ZeroV1ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
fix_req_ids = fix_req_ids,
fix_sampled_token_ids = fix_sampled_token_ids,
fix_draft_tokens_ids = fix_draft_token_ids,
fix_draft_req_ids = fix_draft_req_ids,
is_output_valid=is_output_valid
)
return model_output
\ No newline at end of file
...@@ -6,4 +6,7 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -6,4 +6,7 @@ from vllm.v1.outputs import ModelRunnerOutput
class ZeroV1ModelRunnerOutput(ModelRunnerOutput): class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
# [num_reqs] # [num_reqs]
fix_req_ids: list[str] = None fix_req_ids: list[str] = None
fix_sampled_token_ids:list[list[int]] = None fix_sampled_token_ids:list[list[int]] = None
\ No newline at end of file fix_draft_req_ids:list[list[int]] = None
fix_draft_tokens_ids:list[list[int]] = None
is_output_valid:bool = True
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment