Commit 22a95571 authored by zhuwenwen's avatar zhuwenwen
Browse files

add v1 engine + deepseek r1 mtp + zero-overhead scheduler

parent ac4cc84e
......@@ -764,10 +764,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True).contiguous()
if envs.VLLM_ZERO_OVERHEAD:
decode_block_table_tensor = torch.empty((self._num_decode_tokens, block_table_tensor.shape[1]),
device=block_table_tensor.device)
arange_np = np.arange(self._num_decodes)
indices_np = np.repeat(arange_np, query_lens)
indices = torch.from_numpy(indices_np).pin_memory().to(
block_table_tensor.device, non_blocking=True)
decode_block_table_tensor = block_table_tensor[indices].contiguous()
decode_seq_lens = seq_lens[indices].contiguous()
else:
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:num_decodes, ...],
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:num_decodes], repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus
......
......@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.zero_overhead.v1.gpu_model_runner import execute_model_sampled, zero_prepare_inputs
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (bind_kv_cache, gather_mm_placeholders,
......@@ -1020,6 +1019,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
if envs.VLLM_ZERO_OVERHEAD:
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
self.device, non_blocking=True)
else:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
......@@ -1030,6 +1039,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
......@@ -1441,9 +1451,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ZERO_OVERHEAD:
zero_prepare_inputs(self, scheduler_output, input_ids)
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
......@@ -1592,21 +1599,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if envs.VLLM_ZERO_OVERHEAD:
return execute_model_sampled(self, max_gen_len, sampled_token_ids,
discard_sampled_tokens_req_indices, scheduler_output,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
logprobs_lists,
prompt_logprobs_dict,
finished_sending,
finished_recving,
num_nans_in_logits)
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
......
......@@ -33,6 +33,7 @@ from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
logger = init_logger(__name__)
......@@ -187,6 +188,10 @@ class Worker(WorkerBase):
set_random_seed(self.model_config.seed)
# Construct the model runner
if envs.VLLM_ZERO_OVERHEAD:
self.model_runner: GPUModelRunner = V1ZeroModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
......
......@@ -12,11 +12,15 @@ requsets_valid_token_len = {}
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool:
pooler_output: Optional[torch.Tensor] = None,
use_valid_token_len:bool = False) -> bool:
if use_valid_token_len:
if request.request_id not in requsets_valid_token_len:
requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len
if (valid_num_tokens >= max_model_len
or valid_output_len >= request.max_tokens):
......@@ -60,6 +64,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats: Optional[SpecDecodingStats] = None
# fix last model out in zero overhead
if model_runner_output.fix_req_ids is not None:
for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
if req_id not in scheduler.requests:
continue
......@@ -92,7 +97,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for num_new, output_token_id in enumerate(new_token_ids, 1):
stopped = check_stop(request, scheduler.max_model_len)
stopped = check_stop(request, scheduler.max_model_len, True)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
......@@ -102,7 +107,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output)
pooler_output, True)
if stopped:
kv_transfer_params = scheduler._free_request(request)
......@@ -125,16 +130,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
......@@ -154,16 +149,33 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# fix last model out in zero overhead
if model_runner_output.fix_draft_req_ids is not None:
for req_idx, req_id in enumerate(model_runner_output.fix_draft_req_ids):
if req_id not in scheduler.requests:
continue
request = scheduler.requests[req_id]
# Add newly generated spec token ids to the request.
if model_runner_output.fix_draft_tokens_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
model_runner_output.fix_draft_tokens_ids[req_idx])
else:
request.spec_token_ids = model_runner_output.fix_draft_tokens_ids[req_idx]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for request in scheduler.running:
if request.is_finished():
continue
req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
......@@ -197,7 +209,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if request.has_encoder_inputs:
scheduler._free_encoder_inputs(request)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
......@@ -210,19 +221,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, scheduler.max_model_len)
# if stopped:
# kv_transfer_params = scheduler._free_request(request)
# del new_token_ids[num_new:] # Trim new tokens if needed.
# break
if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len,
False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None
if pooler_outputs:
if model_runner_output.is_output_valid:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output)
# if stopped:
# kv_transfer_params = scheduler._free_request(request)
pooler_output,
False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
# Extract sample logprobs if needed.
if request.sampling_params is not None \
......@@ -253,6 +269,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else:
request.spec_token_ids = spec_token_ids[req_index]
if model_runner_output.is_output_valid:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
if not stopped:
new_running.append(request)
......
This diff is collapsed.
......@@ -7,3 +7,6 @@ class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
# [num_reqs]
fix_req_ids: list[str] = None
fix_sampled_token_ids:list[list[int]] = None
fix_draft_req_ids:list[list[int]] = None
fix_draft_tokens_ids:list[list[int]] = None
is_output_valid:bool = True
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment