Commit a6bf968b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-lzg' into 'v0.9.2-dev'

add v1 engine zero overhead

See merge request dcutoolkit/deeplearing/vllm!177
parents 74a444b5 59e80222
......@@ -61,7 +61,6 @@ def split_scheduler_output(runner, scheduler_output:SchedulerOutput):
else:
new_req_data_right.append(new_req)
#print('###scheduler_output.scheduled_cached_reqs', scheduler_output.scheduled_cached_reqs)
cached_reqs_left = CachedRequestData.make_empty()
cached_reqs_right = CachedRequestData.make_empty()
for req_idx, req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids):
......
......@@ -15,6 +15,8 @@ from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
from vllm import envs
from vllm.zero_overhead.v1.core import engine_core_step
import zmq
from vllm.config import ParallelConfig, VllmConfig
......@@ -226,6 +228,8 @@ class EngineCore:
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
if envs.VLLM_ZERO_OVERHEAD:
return engine_core_step(self)
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
......@@ -235,7 +239,6 @@ class EngineCore:
model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore
return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)
......
......@@ -69,6 +69,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.zero_overhead.v1.gpu_model_runner import execute_model_sampled, zero_prepare_inputs
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
......@@ -1362,6 +1363,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ZERO_OVERHEAD:
zero_prepare_inputs(self, scheduler_output, input_ids)
if envs.VLLM_ENABLE_TBO and not self.use_cuda_graph:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
......@@ -1502,6 +1505,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if envs.VLLM_ZERO_OVERHEAD:
return execute_model_sampled(self, max_gen_len, sampled_token_ids,
discard_sampled_tokens_req_indices, scheduler_output,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
logprobs_lists,
prompt_logprobs_dict,
finished_sending,
finished_recving,
num_nans_in_logits)
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
......
......@@ -28,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream
logger = init_logger(__name__)
......@@ -304,7 +305,12 @@ class Worker(WorkerBase):
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if envs.VLLM_ZERO_OVERHEAD:
use_stream = zero_overhead_stream(self.device)
with torch.cuda.stream(use_stream):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
else:
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
parallel_config = self.vllm_config.parallel_config
......
import torch
from collections import defaultdict
from typing import Optional
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
# Per-request count of output tokens that have been validated so far, keyed by
# request_id. check_stop() seeds a missing entry with 0, and
# zero_overhead_update_from_output() advances it by the number of corrected
# tokens it writes back into the request.
# NOTE(review): nothing in this module removes entries, so the dict grows with
# every request ever seen -- confirm whether finished requests should be evicted.
requsets_valid_token_len = {}
def check_stop(request: Request,
               max_model_len: int,
               pooler_output: Optional[torch.Tensor] = None) -> bool:
    """Decide whether ``request`` is finished, looking only at validated tokens.

    The number of validated output tokens is read from the module-level
    ``requsets_valid_token_len`` dict rather than from the length of
    ``request.output_token_ids``. On a stop, ``request.status`` (and
    ``request.stop_reason`` for stop-token hits) is updated in place.

    Args:
        request: The request to inspect; mutated when a stop is detected.
        max_model_len: Upper bound on prompt tokens plus validated output
            tokens.
        pooler_output: Pooling result for this request, if any. Only
            consulted when ``request.pooling_params`` is truthy.

    Returns:
        True if the request should stop, False otherwise. The first call for
        a request id only registers it with a validated length of 0 and
        returns False.
    """
    req_id = request.request_id
    if req_id not in requsets_valid_token_len:
        # First sighting: start tracking, nothing has been validated yet.
        requsets_valid_token_len[req_id] = 0
        return False

    valid_output_len = requsets_valid_token_len[req_id]
    valid_num_tokens = request.num_prompt_tokens + valid_output_len
    if (valid_num_tokens >= max_model_len
            or valid_output_len >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    if request.pooling_params:
        if pooler_output is not None:
            request.status = RequestStatus.FINISHED_STOPPED
            return True
        return False

    sampling_params = request.sampling_params
    assert sampling_params is not None

    if valid_output_len == 0:
        # No validated token exists yet. Indexing with -1 would wrap around to
        # the newest, still unvalidated entry (or raise IndexError on an empty
        # list), so there is nothing to compare against EOS / stop ids.
        return False

    last_token_id = request.output_token_ids[valid_output_len - 1]
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True
    return False
def zero_overhead_update_from_output(scheduler:Scheduler,
        scheduler_output: SchedulerOutput,
        model_runner_output: ZeroV1ModelRunnerOutput):
    """Apply one model-runner result to the scheduler state (zero-overhead path).

    The function works in two passes:

    1. Fix-up pass over ``model_runner_output.fix_req_ids``: the corrected
       token ids in ``fix_sampled_token_ids`` overwrite entries previously
       appended to each request, the validated-length counter in
       ``requsets_valid_token_len`` is advanced, stop conditions are checked
       and an ``EngineCoreOutput`` is emitted per request. Requests that stop
       here are freed through ``scheduler._free_request``.
    2. Running pass over ``scheduler.running``: the ids in
       ``sampled_token_ids`` are appended to each scheduled request. These
       entries may later be overwritten by the fix-up pass of a subsequent
       call. No output is emitted and nothing is freed in this pass; a
       request is only dropped from ``scheduler.running`` when ``check_stop``
       reports a stop.

    Args:
        scheduler: Scheduler whose requests / running list are updated.
        scheduler_output: The scheduling decision that produced this result.
        model_runner_output: Runner result carrying both the current step's
            data and the fix-up data (``fix_req_ids`` /
            ``fix_sampled_token_ids``, which may be None).

    Returns:
        Mapping of client index to ``EngineCoreOutputs`` for this step.
    """
    sampled_token_ids = model_runner_output.sampled_token_ids
    spec_token_ids = model_runner_output.spec_token_ids
    logprobs = model_runner_output.logprobs
    prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    pooler_outputs = model_runner_output.pooler_output
    num_nans_in_logits = model_runner_output.num_nans_in_logits
    req_id_to_index = model_runner_output.req_id_to_index

    new_running: list[Request] = []
    outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
    spec_decoding_stats: Optional[SpecDecodingStats] = None

    # ---- Pass 1: write back corrected tokens and emit outputs. ----
    # fix_req_ids defaults to None and stays None when the runner has no
    # previous step to correct, so treat None as "nothing to fix".
    fix_req_ids = model_runner_output.fix_req_ids or ()
    fix_sampled_token_ids = model_runner_output.fix_sampled_token_ids
    for req_idx, req_id in enumerate(fix_req_ids):
        if req_id not in scheduler.requests:
            continue
        request = scheduler.requests[req_id]
        generated_token_ids = fix_sampled_token_ids[req_idx]

        # Row of this request in the *current* runner output, or None when the
        # request is not part of the current batch. The per-row lookups below
        # (pooler output, logprobs, spec tokens) are skipped in that case.
        # NOTE(review): those rows belong to the current step while the tokens
        # emitted here come from the fix-up data -- confirm this pairing is
        # the intended one.
        req_index = req_id_to_index.get(req_id)

        if req_id not in requsets_valid_token_len:
            requsets_valid_token_len[req_id] = 0
        valid_output_len = requsets_valid_token_len[req_id]
        # Offset from the end of the token lists of the first entry to fix.
        fix_offset = valid_output_len - request.num_output_tokens
        if isinstance(generated_token_ids, int):
            request._output_token_ids[fix_offset] = generated_token_ids
            request._all_token_ids[fix_offset] = generated_token_ids
            requsets_valid_token_len[req_id] += 1
            # Everything below iterates / slices the new tokens, so hand it a
            # list rather than a bare int.
            generated_token_ids = [generated_token_ids]
        else:
            valid_output_end = (valid_output_len + len(generated_token_ids) -
                                request.num_output_tokens)
            if valid_output_end == 0:
                # A slice end of 0 would be empty; the fix reaches the tail.
                request._output_token_ids[fix_offset:] = generated_token_ids
                request._all_token_ids[fix_offset:] = generated_token_ids
            else:
                request._output_token_ids[
                    fix_offset:valid_output_end] = generated_token_ids
                request._all_token_ids[
                    fix_offset:valid_output_end] = generated_token_ids
            requsets_valid_token_len[req_id] += len(generated_token_ids)

        stopped = False
        new_logprobs = None
        new_token_ids = generated_token_ids
        kv_transfer_params = None

        # Check for stop and update request state.
        # This must be called before we make the EngineCoreOutput.
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            stopped = check_stop(request, scheduler.max_model_len)
            if stopped:
                kv_transfer_params = scheduler._free_request(request)
                del new_token_ids[num_new:]  # Trim new tokens if needed.
                break

        pooler_output = None
        if pooler_outputs and req_index is not None:
            pooler_output = pooler_outputs[req_index]
            stopped = check_stop(request, scheduler.max_model_len,
                                 pooler_output)
            if stopped:
                kv_transfer_params = scheduler._free_request(request)

        # Extract sample logprobs if needed.
        if request.sampling_params is not None \
            and request.sampling_params.logprobs is not None \
            and logprobs and req_index is not None:
            # NOTE: once we support N tokens per step (spec decode),
            # the outer lists can be of length > 1.
            new_logprobs = logprobs.slice(req_index, req_index + 1)

        if new_token_ids and scheduler.structured_output_manager.should_advance(
                request):
            # NOTE: structured_output_request
            # should not be None if use_structured_output, we have
            # check above, so safe to ignore type warning
            request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                req_id, new_token_ids)

        if num_nans_in_logits is not None and req_id in num_nans_in_logits:
            request.num_nans_in_logits = num_nans_in_logits[req_id]

        # Add newly generated spec token ids to the request.
        if spec_token_ids is not None and req_index is not None:
            if scheduler.structured_output_manager.should_advance(request):
                metadata = request.structured_output_request
                # Needs to happen after new_token_ids are accepted.
                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                    spec_token_ids[req_index])
            else:
                request.spec_token_ids = spec_token_ids[req_index]

        # Get prompt logprobs for this request.
        prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
        if new_token_ids or pooler_output is not None \
            or kv_transfer_params:
            # Add EngineCoreOutput for this Request.
            outputs[request.client_index].append(
                EngineCoreOutput(
                    request_id=req_id,
                    new_token_ids=new_token_ids,
                    finish_reason=request.get_finished_reason(),
                    new_logprobs=new_logprobs,
                    new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                    pooling_output=pooler_output,
                    stop_reason=request.stop_reason,
                    events=request.take_events(),
                    kv_transfer_params=kv_transfer_params,
                    num_cached_tokens=request.num_cached_tokens,
                ))
        else:
            # Invariant: EngineCore returns no partial prefill outputs.
            assert not prompt_logprobs_tensors

    # ---- Pass 2: append this step's tokens to the running requests. ----
    # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
    # loop can be a performance bottleneck. We should do our best to avoid
    # expensive operations inside the loop.
    for request in scheduler.running:
        req_id = request.request_id
        num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
        if num_tokens_scheduled == 0:
            # The request was not scheduled in this step.
            new_running.append(request)
            continue

        req_index = req_id_to_index[req_id]
        generated_token_ids = sampled_token_ids[
            req_index] if sampled_token_ids else []

        scheduled_spec_token_ids = (
            scheduler_output.scheduled_spec_decode_tokens.get(req_id))
        if scheduled_spec_token_ids:
            # num_computed_tokens represents the number of tokens
            # processed in the current step, considering scheduled
            # tokens and rejections. If some tokens are rejected,
            # num_computed_tokens is decreased by the number of rejected
            # tokens, which is given by:
            # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
            num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                                   len(generated_token_ids))
            request.num_computed_tokens -= num_tokens_rejected
            spec_decoding_stats = scheduler.make_spec_decoding_stats(
                spec_decoding_stats,
                num_draft_tokens=len(scheduled_spec_token_ids),
                num_accepted_tokens=len(generated_token_ids) - 1)

        # NOTE(woosuk): This has to be executed after updating
        # `request.num_computed_tokens`.
        if request.has_encoder_inputs:
            scheduler._free_encoder_inputs(request)

        stopped = False
        new_logprobs = None
        new_token_ids = generated_token_ids
        kv_transfer_params = None

        # Append generated tokens and check for stop. Note that if
        # a request is still being prefilled, we expect the model runner
        # to return empty token ids for the request.
        # Unlike the fix-up pass, a stop detected here neither frees the
        # request nor trims the new tokens; it only keeps the request out of
        # the new running list.
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            request.append_output_token_ids(output_token_id)
            stopped = check_stop(request, scheduler.max_model_len)

        pooler_output = None
        if pooler_outputs:
            pooler_output = pooler_outputs[req_index]
            stopped = check_stop(request, scheduler.max_model_len,
                                 pooler_output)

        # Extract sample logprobs if needed.
        if request.sampling_params is not None \
            and request.sampling_params.logprobs is not None and logprobs:
            # NOTE: once we support N tokens per step (spec decode),
            # the outer lists can be of length > 1.
            new_logprobs = logprobs.slice(req_index, req_index + 1)

        if new_token_ids and scheduler.structured_output_manager.should_advance(
                request):
            # NOTE: structured_output_request
            # should not be None if use_structured_output, we have
            # check above, so safe to ignore type warning
            request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                req_id, new_token_ids)

        if num_nans_in_logits is not None and req_id in num_nans_in_logits:
            request.num_nans_in_logits = num_nans_in_logits[req_id]

        # Add newly generated spec token ids to the request.
        if spec_token_ids is not None:
            if scheduler.structured_output_manager.should_advance(request):
                metadata = request.structured_output_request
                # Needs to happen after new_token_ids are accepted.
                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                    spec_token_ids[req_index])
            else:
                request.spec_token_ids = spec_token_ids[req_index]

        if not stopped:
            new_running.append(request)
    scheduler.running = new_running

    # KV Connector: update state for finished KV Transfers.
    scheduler._update_from_kv_xfer_finished(model_runner_output)

    # Create EngineCoreOutputs for all clients that have requests with
    # outputs in this step.
    engine_core_outputs = {
        client_index: EngineCoreOutputs(outputs=outs)
        for client_index, outs in outputs.items()
    }

    finished_req_ids = scheduler.finished_req_ids_dict
    if finished_req_ids:
        # Include ids of requests that finished since last outputs
        # were sent.
        for client_index, finished_set in finished_req_ids.items():
            # Set finished request set in EngineCoreOutputs for this client.
            if (eco := engine_core_outputs.get(client_index)) is not None:
                eco.finished_requests = finished_set
            else:
                engine_core_outputs[client_index] = EngineCoreOutputs(
                    finished_requests=finished_set)
        finished_req_ids.clear()

    if engine_core_outputs:
        # Return stats to only one of the front-ends.
        next(iter(engine_core_outputs.values())).scheduler_stats = (
            scheduler.make_stats(spec_decoding_stats))

    return engine_core_outputs
def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Run one schedule -> execute -> update cycle for ``core``.

    Returns:
        A pair ``(outputs, executed)``. ``outputs`` maps client index to the
        engine-core outputs of this step and is an empty dict when the
        scheduler holds no requests. ``executed`` is True when the scheduler
        output contained at least one scheduled token.
    """
    scheduler = core.scheduler

    # Nothing unfinished, and nothing finished-but-still-batched: no work.
    if not scheduler.has_requests():
        return {}, False

    scheduler_output = scheduler.schedule()
    model_output = core.execute_model(scheduler_output)

    # Only the zero-overhead runner output carries fix-up data; every other
    # output type goes through the scheduler's regular update path.
    if isinstance(model_output, ZeroV1ModelRunnerOutput):
        engine_core_outputs = zero_overhead_update_from_output(
            scheduler, scheduler_output, model_output)  # type: ignore
    else:
        engine_core_outputs = scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

    model_executed = scheduler_output.total_num_scheduled_tokens > 0
    return engine_core_outputs, model_executed
\ No newline at end of file
import torch
import numpy as np
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_tp_group
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
class V1ZeroModelRunner():
    """State carried from one sampling step to the next on the zero-overhead path.

    Attributes are plain public fields that the module-level helpers read and
    write directly.
    """

    def __init__(self):
        # Nothing has been sampled yet: both token holders start out empty.
        self.last_sampled_token_ids = None
        # Request ids recorded for the last sampled step, and for each of
        # them a matching length entry.
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        # Timing is disabled; the event is only used for record/synchronize.
        self.last_sampler_event = torch.cuda.Event(enable_timing=False)
        self.last_sampler_host_tokens = None
        self.token_ids_cpu_fix_recode = []

    def set_last_sampled_token_ids(self, sampled_token_ids):
        """Store ``sampled_token_ids`` and start fresh per-request bookkeeping.

        The request-id and length lists are replaced by new empty lists
        (not cleared in place), so references to the old lists stay intact.
        """
        self.last_sampled_token_ids = sampled_token_ids
        self.last_sampled_req_ids, self.last_sampled_token_lens = [], []
# Module-level instance shared by zero_prepare_inputs() and
# execute_model_sampled(). Constructing it creates a torch.cuda.Event, so that
# happens as soon as this module is imported.
v1_zero_overhead = V1ZeroModelRunner()
def zero_prepare_inputs(runner, scheduler_output, input_ids):
    """Patch ``input_ids`` in place with the tokens sampled in the previous step.

    For every request of the current batch that is also recorded in
    ``v1_zero_overhead.last_sampled_req_ids``, the previously sampled tokens
    are copied from ``v1_zero_overhead.last_sampled_token_ids`` into
    ``input_ids`` at the request's starting offset. The copy indexes one
    tensor with another, so no token value is read on the host here.

    Args:
        runner: Model runner; ``runner.input_batch.req_ids`` gives the batch
            order and ``runner.device`` the target device for index tensors.
        scheduler_output: Provides ``num_scheduled_tokens`` per request id,
            used to compute each request's starting offset in ``input_ids``.
        input_ids: Flat token tensor to update in place.

    Returns:
        None. Does nothing when no previous sample has been stored or when no
        request of the batch was sampled in the previous step.
    """
    last_sampled = v1_zero_overhead.last_sampled_token_ids
    if last_sampled is None:
        return
    sampled_tokens_num = last_sampled.shape[1]

    # Position of each request id in the previous step's list, built once so
    # the batch loop does not rescan the list per request. setdefault keeps
    # the first occurrence, matching what list.index() would return.
    # NOTE(review): the position in last_sampled_req_ids is used as the row of
    # last_sampled_token_ids -- confirm the two always line up.
    prev_position = {}
    for position, prev_req_id in enumerate(
            v1_zero_overhead.last_sampled_req_ids):
        prev_position.setdefault(prev_req_id, position)

    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    update_req_indices = []  # offsets into the flattened previous samples
    input_ids_indices = []   # matching start offsets into input_ids
    token_idx = 0
    for req_id in runner.input_batch.req_ids:
        position = prev_position.get(req_id)
        if position is not None:
            update_req_indices.append(position * sampled_tokens_num)
            input_ids_indices.append(token_idx)
        # Every request advances the running offset, patched or not.
        token_idx += num_scheduled_tokens[req_id]

    if not update_req_indices:
        return

    update_req_indices_tensor = async_tensor_h2d(update_req_indices,
                                                 torch.int32, runner.device,
                                                 True)
    input_ids_indices_tensor = async_tensor_h2d(input_ids_indices,
                                                torch.int32, runner.device,
                                                True)
    flat_last_sampled = last_sampled.flatten()
    for i in range(sampled_tokens_num):
        input_ids[input_ids_indices_tensor + i] = flat_last_sampled[
            update_req_indices_tensor + i]
def execute_model_sampled(runner, max_gen_len, sampled_token_ids,
                          discard_sampled_tokens_req_indices, scheduler_output,
                          sampling_metadata,
                          hidden_states,
                          sample_hidden_states,
                          aux_hidden_states,
                          spec_decode_metadata,
                          attn_metadata,
                          logprobs_lists,
                          prompt_logprobs_dict,
                          finished_sending,
                          finished_recving,
                          num_nans_in_logits
                          ):
    """Build the runner output for a sampled step without waiting for its tokens.

    When ``max_gen_len == 1`` the freshly sampled tokens are copied to the
    host with a non-blocking transfer and an event is recorded; the values
    returned for this step are placeholders (all ones). The real values of
    the *previous* step are read once its event has completed, written back
    into the runner's CPU-side token buffers, and returned as
    ``fix_req_ids`` / ``fix_sampled_token_ids``.

    When ``max_gen_len != 1`` the sampled tokens are parsed through
    ``runner.rejection_sampler.parse_output`` and no fix-up data is produced.

    Args:
        runner: The GPU model runner whose input batch and request states are
            updated.
        max_gen_len: Last dimension of ``sampled_token_ids``.
        sampled_token_ids: Tensor of token ids sampled in this step.
        discard_sampled_tokens_req_indices: Batch rows whose sampled tokens
            must be dropped.
        scheduler_output, sampling_metadata, hidden_states,
        sample_hidden_states, aux_hidden_states, spec_decode_metadata,
        attn_metadata: Forwarded to ``runner.propose_draft_token_ids`` when
            speculative decoding is configured.
        logprobs_lists, prompt_logprobs_dict, finished_sending,
        finished_recving, num_nans_in_logits: Passed through into the
            returned output unchanged.

    Returns:
        A ``ZeroV1ModelRunnerOutput``; ``fix_req_ids`` and
        ``fix_sampled_token_ids`` are None when there is nothing to correct.
    """
    fix_req_ids = None
    fix_sampled_token_ids = None
    if max_gen_len == 1:
        # No spec decode tokens.
        # Identity check: the stored value is a tensor once set, and `!= None`
        # would go through the tensor's comparison operator.
        if v1_zero_overhead.last_sampler_host_tokens is not None:
            # Wait for the previous step's device-to-host copy to land.
            v1_zero_overhead.last_sampler_event.synchronize()
            fix_sampled_token_ids = (
                v1_zero_overhead.last_sampler_host_tokens.tolist())
            # Replace the placeholders written last step with real values.
            for req_idx, start_idx, end_idx in (
                    v1_zero_overhead.token_ids_cpu_fix_recode):
                runner.input_batch.token_ids_cpu[
                    req_idx,
                    start_idx:end_idx] = fix_sampled_token_ids[req_idx]
            fix_req_ids = v1_zero_overhead.last_sampled_req_ids
            # NOTE(review): req_idx below is a position in
            # last_sampled_req_ids but is used as a row of the host token
            # list. The two only coincide when no earlier row was skipped
            # while the list was built -- confirm.
            for req_idx, req_id in enumerate(fix_req_ids):
                if req_id in runner.requests:
                    req_state = runner.requests[req_id]
                    token_idx = v1_zero_overhead.last_sampled_token_lens[
                        req_idx]
                    req_state.output_token_ids[
                        token_idx] = fix_sampled_token_ids[req_idx][0]
        # Start the async copy for this step and remember where it ends.
        v1_zero_overhead.last_sampler_host_tokens = sampled_token_ids.to(
            'cpu', non_blocking=True)
        v1_zero_overhead.last_sampler_event.record()
        # Also swaps in fresh, empty req-id / length lists, so fix_req_ids
        # above keeps pointing at the previous step's list.
        v1_zero_overhead.set_last_sampled_token_ids(sampled_token_ids)
        # Placeholder values with the same shape as the real samples.
        valid_sampled_token_ids = np.ones(sampled_token_ids.shape,
                                          dtype=int).tolist()
    else:
        # Includes spec decode tokens.
        valid_sampled_token_ids = runner.rejection_sampler.parse_output(
            sampled_token_ids,
            runner.input_batch.vocab_size,
        )
    # Mask out the sampled tokens that should not be sampled.
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()

    # Cache the sampled tokens in the model runner, so that the scheduler
    # doesn't need to send them back.
    # NOTE(woosuk): As an exception, when using PP, the scheduler sends
    # the sampled tokens back, because there's no direct communication
    # between the first-stage worker and the last-stage worker.
    v1_zero_overhead.token_ids_cpu_fix_recode.clear()
    for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
        if not sampled_ids:
            continue

        start_idx = runner.input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= runner.max_model_len, (
            "Sampled token IDs exceed the max model length. "
            f"Total number of tokens: {end_idx} > max_model_len: "
            f"{runner.max_model_len}")

        runner.input_batch.token_ids_cpu[req_idx,
                                         start_idx:end_idx] = sampled_ids
        # Remember the slot so the next call can overwrite it.
        v1_zero_overhead.token_ids_cpu_fix_recode.append(
            [req_idx, start_idx, end_idx])
        runner.input_batch.num_tokens_no_spec[req_idx] = end_idx
        runner.input_batch.num_tokens[req_idx] = end_idx
        req_id = runner.input_batch.req_ids[req_idx]
        if req_id in runner.requests:
            req_state = runner.requests[req_id]
            v1_zero_overhead.last_sampled_req_ids.append(req_id)
            # Index in output_token_ids where this step's first token lands.
            v1_zero_overhead.last_sampled_token_lens.append(
                len(req_state.output_token_ids))
            req_state.output_token_ids.extend(sampled_ids)

    if not runner.speculative_config:
        # Speculative decoding is not enabled.
        spec_token_ids = None
    else:
        spec_token_ids = runner.propose_draft_token_ids(
            scheduler_output,
            valid_sampled_token_ids,
            sampling_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            spec_decode_metadata,
            attn_metadata,
        )

    # Clear KVConnector state after all KVs are generated.
    if has_kv_transfer_group():
        get_kv_transfer_group().clear_connector_metadata()

    runner.eplb_step()

    model_output = ZeroV1ModelRunnerOutput(
        req_ids=runner.input_batch.req_ids,
        req_id_to_index=runner.input_batch.req_id_to_index,
        sampled_token_ids=valid_sampled_token_ids,
        spec_token_ids=spec_token_ids,
        logprobs=logprobs_lists,
        prompt_logprobs_dict=prompt_logprobs_dict,
        pooler_output=[],
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        num_nans_in_logits=num_nans_in_logits,
        fix_req_ids=fix_req_ids,
        fix_sampled_token_ids=fix_sampled_token_ids,
    )
    return model_output
\ No newline at end of file
from dataclasses import dataclass
from typing import Optional

from vllm.v1.outputs import ModelRunnerOutput
@dataclass
class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
    """``ModelRunnerOutput`` extended with two optional fix-up fields.

    Both fields default to None, so the annotations are ``Optional``.
    """

    # Request ids that the entries of fix_sampled_token_ids belong to,
    # position by position; None when no fix-up data is attached.
    fix_req_ids: Optional[list[str]] = None
    # One list of token ids per entry; None when no fix-up data is attached.
    fix_sampled_token_ids: Optional[list[list[int]]] = None
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment