Commit 5dcaac2f authored by lizhigong's avatar lizhigong
Browse files

add v1 engine + deepseek r1 mtp + zero-overhead scheduler

parent 59e80222
...@@ -651,10 +651,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -651,10 +651,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True).contiguous() block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...], if envs.VLLM_ZERO_OVERHEAD:
repeats, dim=0).contiguous() decode_block_table_tensor = torch.empty((self._num_decode_tokens, block_table_tensor.shape[1]),
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous() device=block_table_tensor.device)
arange_np = np.arange(self._num_decodes)
indices_np = np.repeat(arange_np, query_lens)
indices = torch.from_numpy(indices_np).pin_memory().to(
block_table_tensor.device, non_blocking=True)
decode_block_table_tensor = block_table_tensor[indices].contiguous()
decode_seq_lens = seq_lens[indices].contiguous()
else:
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True).contiguous() seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
......
...@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.zero_overhead.v1.gpu_model_runner import execute_model_sampled, zero_prepare_inputs
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
...@@ -954,15 +953,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -954,15 +953,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9] # [0, 1, 2, 5, 6, 9]
target_logits_indices += arange target_logits_indices += arange
# TODO: Optimize the CPU -> GPU copy. if envs.VLLM_ZERO_OVERHEAD:
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device, logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
non_blocking=True) non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to( target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
self.device, non_blocking=True) self.device, non_blocking=True)
else:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
# Compute the draft token ids. # Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208] # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
...@@ -1363,8 +1372,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1363,8 +1372,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ZERO_OVERHEAD:
zero_prepare_inputs(self, scheduler_output, input_ids)
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph: if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
...@@ -1506,21 +1513,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1506,21 +1513,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
if envs.VLLM_ZERO_OVERHEAD:
return execute_model_sampled(self, max_gen_len, sampled_token_ids,
discard_sampled_tokens_req_indices, scheduler_output,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
logprobs_lists,
prompt_logprobs_dict,
finished_sending,
finished_recving,
num_nans_in_logits)
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids.tolist()
......
...@@ -29,6 +29,7 @@ from vllm.v1.utils import report_usage_stats ...@@ -29,6 +29,7 @@ from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -163,8 +164,12 @@ class Worker(WorkerBase): ...@@ -163,8 +164,12 @@ class Worker(WorkerBase):
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
# Construct the model runner # Construct the model runner
self.model_runner: GPUModelRunner = GPUModelRunner( if envs.VLLM_ZERO_OVERHEAD:
self.vllm_config, self.device) self.model_runner: GPUModelRunner = V1ZeroModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
if self.rank == 0: if self.rank == 0:
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
......
...@@ -14,11 +14,15 @@ requsets_valid_token_len = {} ...@@ -14,11 +14,15 @@ requsets_valid_token_len = {}
def check_stop(request: Request, def check_stop(request: Request,
max_model_len: int, max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool: pooler_output: Optional[torch.Tensor] = None,
if request.request_id not in requsets_valid_token_len: use_valid_token_len:bool = False) -> bool:
requsets_valid_token_len[request.request_id] = 0 if use_valid_token_len:
return False if request.request_id not in requsets_valid_token_len:
valid_output_len = requsets_valid_token_len[request.request_id] requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len valid_num_tokens = request.num_prompt_tokens + valid_output_len
if (valid_num_tokens >= max_model_len if (valid_num_tokens >= max_model_len
or valid_output_len >= request.max_tokens): or valid_output_len >= request.max_tokens):
...@@ -62,110 +66,119 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -62,110 +66,119 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
# fix last model out in zero overhead # fix last model out in zero overhead
for req_idx, req_id in enumerate(model_runner_output.fix_req_ids): if model_runner_output.fix_req_ids is not None:
if req_id not in scheduler.requests: for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
continue if req_id not in scheduler.requests:
request = scheduler.requests[req_id] continue
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx] request = scheduler.requests[req_id]
if req_id not in requsets_valid_token_len: generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
requsets_valid_token_len[req_id] = 0 if req_id not in requsets_valid_token_len:
valid_output_len = requsets_valid_token_len[req_id] requsets_valid_token_len[req_id] = 0
fix_offset = valid_output_len - request.num_output_tokens valid_output_len = requsets_valid_token_len[req_id]
if isinstance(generated_token_ids, int): fix_offset = valid_output_len - request.num_output_tokens
request._output_token_ids[fix_offset] = generated_token_ids if isinstance(generated_token_ids, int):
request._all_token_ids[fix_offset] = generated_token_ids request._output_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1 request._all_token_ids[fix_offset] = generated_token_ids
else: requsets_valid_token_len[req_id] += 1
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else: else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids if valid_output_end == 0:
requsets_valid_token_len[req_id] += len(generated_token_ids) request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
stopped = False # Check for stop and update request state.
new_logprobs = None # This must be called before we make the EngineCoreOutput.
new_token_ids = generated_token_ids for num_new, output_token_id in enumerate(new_token_ids, 1):
kv_transfer_params = None stopped = check_stop(request, scheduler.max_model_len, True)
if stopped:
# Check for stop and update request state. kv_transfer_params = scheduler._free_request(request)
# This must be called before we make the EngineCoreOutput. del new_token_ids[num_new:] # Trim new tokens if needed.
for num_new, output_token_id in enumerate(new_token_ids, 1): break
stopped = check_stop(request, scheduler.max_model_len)
if stopped: pooler_output = None
kv_transfer_params = scheduler._free_request(request) if pooler_outputs:
del new_token_ids[num_new:] # Trim new tokens if needed. pooler_output = pooler_outputs[req_index]
break stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True)
pooler_output = None if stopped:
if pooler_outputs: kv_transfer_params = scheduler._free_request(request)
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len, # Extract sample logprobs if needed.
pooler_output) if request.sampling_params is not None \
if stopped: and request.sampling_params.logprobs is not None and logprobs:
kv_transfer_params = scheduler._free_request(request) # NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
# Extract sample logprobs if needed. new_logprobs = logprobs.slice(req_index, req_index + 1)
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs: if new_token_ids and scheduler.structured_output_manager.should_advance(
# NOTE: once we support N tokens per step (spec decode), request):
# the outer lists can be of length > 1. # NOTE: structured_output_request
new_logprobs = logprobs.slice(req_index, req_index + 1) # should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
if new_token_ids and scheduler.structured_output_manager.should_advance( request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request): req_id, new_token_ids)
# NOTE: structured_output_request
# should not be None if use_structured_output, we have # spec_token_ids comes from the model runner output
# check above, so safe to ignore type warning if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] request.num_nans_in_logits = num_nans_in_logits[req_id]
req_id, new_token_ids)
# Get prompt logprobs for this request.
# spec_token_ids comes from the model runner output prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if new_token_ids or pooler_output is not None \
request.num_nans_in_logits = num_nans_in_logits[req_id] or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids[req_index])
else: else:
request.spec_token_ids = spec_token_ids[req_index] # Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# fix last model out in zero overhead
if model_runner_output.fix_draft_req_ids is not None:
for req_idx, req_id in enumerate(model_runner_output.fix_draft_req_ids):
if req_id not in scheduler.requests:
continue
request = scheduler.requests[req_id]
# Add newly generated spec token ids to the request.
if model_runner_output.fix_draft_tokens_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
model_runner_output.fix_draft_tokens_ids[req_idx])
else:
request.spec_token_ids = model_runner_output.fix_draft_tokens_ids[req_idx]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid # loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop. # expensive operations inside the loop.
for request in scheduler.running: for request in scheduler.running:
if request.is_finished():
continue
req_id = request.request_id req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0: if num_tokens_scheduled == 0:
...@@ -199,7 +212,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -199,7 +212,6 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if request.has_encoder_inputs: if request.has_encoder_inputs:
scheduler._free_encoder_inputs(request) scheduler._free_encoder_inputs(request)
stopped = False
new_logprobs = None new_logprobs = None
new_token_ids = generated_token_ids new_token_ids = generated_token_ids
kv_transfer_params = None kv_transfer_params = None
...@@ -212,19 +224,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -212,19 +224,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, scheduler.max_model_len)
# if stopped: if model_runner_output.is_output_valid:
# kv_transfer_params = scheduler._free_request(request) stopped = check_stop(request, scheduler.max_model_len,
# del new_token_ids[num_new:] # Trim new tokens if needed. False)
# break if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None pooler_output = None
if pooler_outputs: if pooler_outputs:
pooler_output = pooler_outputs[req_index] if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len, pooler_output = pooler_outputs[req_index]
pooler_output) stopped = check_stop(request, scheduler.max_model_len,
# if stopped: pooler_output,
# kv_transfer_params = scheduler._free_request(request) False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params is not None \ if request.sampling_params is not None \
...@@ -255,6 +272,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -255,6 +272,27 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else: else:
request.spec_token_ids = spec_token_ids[req_index] request.spec_token_ids = spec_token_ids[req_index]
if model_runner_output.is_output_valid:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
if not stopped: if not stopped:
new_running.append(request) new_running.append(request)
......
This diff is collapsed.
...@@ -8,4 +8,7 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -8,4 +8,7 @@ from vllm.v1.outputs import ModelRunnerOutput
class ZeroV1ModelRunnerOutput(ModelRunnerOutput): class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
# [num_reqs] # [num_reqs]
fix_req_ids: list[str] = None fix_req_ids: list[str] = None
fix_sampled_token_ids:list[list[int]] = None fix_sampled_token_ids:list[list[int]] = None
\ No newline at end of file fix_draft_req_ids:list[list[int]] = None
fix_draft_tokens_ids:list[list[int]] = None
is_output_valid:bool = True
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment