Commit 4ff58b66 authored by lizhigong's avatar lizhigong
Browse files

debug v0 zero overhead schedule

parent 54294854
......@@ -239,8 +239,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
# block_tables = torch.from_numpy(input_block_tables).to(
# device, non_blocking=True)
block_tables = async_tensor_h2d(input_block_tables.tolist(), torch.int32,
device, self.runner.pin_memory)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
......
......@@ -1450,6 +1450,8 @@ class LLM:
if use_tqdm:
pbar.close()
if is_zero_overhead():
self.llm_engine.finish_thread()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
......
......@@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.v0.sampler import ZeroOverheadSampler
from vllm.zero_overhead.v0.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
......@@ -41,6 +40,7 @@ def get_sampler() -> torch.nn.Module:
from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler()
if is_zero_overhead():
from vllm.zero_overhead.v0.sampler import ZeroOverheadSampler
return ZeroOverheadSampler()
return Sampler()
......
......@@ -60,7 +60,6 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.v0.model_runner import ZeroOverheadModelInputForGpuBuilder
from vllm.zero_overhead.v0.utils import is_zero_overhead
if TYPE_CHECKING:
......@@ -1639,6 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead():
from vllm.zero_overhead.v0.model_runner import ZeroOverheadModelInputForGpuBuilder
_builder_cls = ZeroOverheadModelInputForGpuBuilder
def make_model_input_from_broadcasted_tensor_dict(
......
from collections import Counter
from functools import partial
import os
import queue
......@@ -13,7 +12,7 @@ from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.engine.llm_engine import _LOCAL_LOGGING_INTERVAL_SEC, LLMEngine, SchedulerContext, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.entrypoints import logger
from vllm.logger import init_logger
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import ProcessorInputs
......@@ -31,8 +30,10 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, ParallelSampleSequenceGroup, SequenceGroup, SequenceGroupBase, SequenceGroupMetadata
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.version import __version__ as VLLM_VERSION
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled
from vllm.utils import resolve_obj_by_qualname, weak_bind
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
from vllm.zero_overhead.v0.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
from vllm.zero_overhead.v0.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.v0.tokenizer import ZeroOverheadDetokenizer
......@@ -40,6 +41,8 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.profiler.prof import profile
logger = init_logger(__name__)
class ZeroOverheadEngine(LLMEngine):
def __init__(
self,
......@@ -77,7 +80,7 @@ class ZeroOverheadEngine(LLMEngine):
logger.info(
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
ZLIB_VERSION,
VLLM_VERSION,
vllm_config,
use_cached_outputs,
)
......@@ -259,6 +262,7 @@ class ZeroOverheadEngine(LLMEngine):
self._skip_scheduling_next_step = False
self.async_d2h = None
self.last_record = None
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_event = torch.cuda.Event(enable_timing=False)
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.q_recorder = queue.Queue()
......@@ -277,6 +281,7 @@ class ZeroOverheadEngine(LLMEngine):
self.sem_m2s.release()
def thread_zero_overhead(self):
logger.info('zero overhead thread start!')
try:
while True:
self.sem_m2s.acquire()
......@@ -290,12 +295,9 @@ class ZeroOverheadEngine(LLMEngine):
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None:
last_output = self.last_record[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
last_sampler = self.last_record[1]
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
......@@ -322,10 +324,7 @@ class ZeroOverheadEngine(LLMEngine):
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids,
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
......@@ -334,7 +333,8 @@ class ZeroOverheadEngine(LLMEngine):
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs]
last_sampler = get_last_sampler()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
except Exception as e:
print(f"thread_zero_overhead error : {e}")
......@@ -353,12 +353,12 @@ class ZeroOverheadEngine(LLMEngine):
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, seq_group_metadata_list, scheduler_outputs = recode_output
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, seq_group_metadata_list,
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
......@@ -398,12 +398,13 @@ class ZeroOverheadEngine(LLMEngine):
def _fix_last_step(
self, output: List[SamplerOutput],
last_sampler: SampleRecorder,
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list = self.async_d2h.tolist()
sample_out_ids = output[0].sampler_out_ids.tolist()
sample_out_ids = last_sampler.seq_id.tolist()
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
......
......@@ -35,7 +35,7 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def build(self) -> ModelInputForGPU:
model_input = super().build()
last_sampler = get_last_sampler()
if last_sampler.sampled_token_ids_tensor is not None:
if last_sampler is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
......
......@@ -3,11 +3,10 @@ from typing import Dict, List, Optional
import torch
from vllm import envs
from vllm.model_executor.layers.rejection_sampler import _multinomial
from vllm.model_executor.layers.sampler import MultinomialSamplesType, SampleMetadataType, \
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, _build_sampler_output, \
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
from vllm.sampling_params import SamplingType
......@@ -17,13 +16,16 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None
class SampleRecorder:
def __init__(self):
self.seq_id:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None
last_sampler = SampleRecorder()
last_sampler = None
def get_last_sampler():
return last_sampler
......@@ -55,6 +57,8 @@ class ZeroOverheadSampler(Sampler):
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
global last_sampler
last_sampler = SampleRecorder()
assert logits is not None
_, vocab_size = logits.shape
......@@ -282,7 +286,6 @@ def _sample_with_torch(
sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {}
greedy_samples: Optional[torch.Tensor] = None
beam_search_logprobs: Optional[torch.Tensor] = None
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
......@@ -356,11 +359,6 @@ def _sample_with_torch(
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
......@@ -368,7 +366,6 @@ def _sample_with_torch(
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output:
......
......@@ -44,7 +44,7 @@ class ZeroOverheadStopChecker(StopChecker):
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id(self.zero_overhead)
last_token_id = seq.zero_overhead_get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
......
......@@ -23,6 +23,12 @@ def _update_input_tokens(
output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token)
_update_input_tokens_ptr = None
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
global _update_input_tokens_ptr
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
if _update_input_tokens_ptr is None:
_update_input_tokens_ptr = _update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
else:
_update_input_tokens_ptr[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment