Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
......@@ -6,8 +6,8 @@ import time
import warnings
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Type, TypeVar, Union)
import numpy as np
import torch
......@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY, InputRegistry
......@@ -29,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
......@@ -41,10 +43,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available)
flatten_2d_lists, is_hip, is_pin_memory_available,
supports_dynamo)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
......@@ -59,10 +61,14 @@ logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# All the token sizes that **can** be captured by cudagraph.
# The list can be arbitrarily large; it currently contains
# 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# The sizes that are actually captured are limited per model runner by
# the scheduler's max_num_seqs (see _get_max_graph_batch_size).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
_NUM_WARMUP_ITERS = 2
......@@ -90,6 +96,9 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
async_callback: Optional[Callable] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
......@@ -499,23 +508,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")
# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
if not prefix_cache_hit:
return
assert computed_block_nums is not None
# The cache hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
......@@ -632,7 +666,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and batch_size <= self.runner.max_batchsize_to_capture
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
def build(self) -> ModelInputForGPU:
......@@ -818,6 +852,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.max_batchsize_to_capture = _get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
......@@ -835,7 +871,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
......@@ -945,7 +981,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
self.model = torch.compile(self.model,
fullgraph=True,
backend="eager")
......@@ -1220,7 +1256,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
start_time = time.perf_counter()
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
......@@ -1248,8 +1284,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
None
] * self.parallel_config.pipeline_parallel_size
graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
graph_batch_size = self.max_batchsize_to_capture
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
......@@ -1357,7 +1392,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
......@@ -1481,6 +1516,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
......@@ -1672,3 +1710,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _get_max_graph_batch_size(max_num_seqs: int) -> int:
    """Return the largest batch size for which a graph may be captured.

    ``max_num_seqs`` is first padded with ``_get_graph_batch_size`` (which
    also handles the small edge cases such as 1, 2 and 4). If the padded
    value is one of ``_BATCH_SIZES_TO_CAPTURE`` it is returned unchanged.
    Otherwise the padded value must be larger than every capturable size,
    and the largest entry of ``_BATCH_SIZES_TO_CAPTURE`` is returned.

    Args:
        max_num_seqs: Maximum number of sequences in a batch.

    Returns:
        The padded size, clamped to the largest capturable batch size.
    """
    padded = _get_graph_batch_size(max_num_seqs)
    if padded not in _BATCH_SIZES_TO_CAPTURE:
        # Not a capturable size: it can only be beyond the end of the list.
        largest_capturable = _BATCH_SIZES_TO_CAPTURE[-1]
        assert padded > largest_capturable
        return largest_capturable
    return padded
......@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
......
import dataclasses
import functools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
......@@ -13,9 +16,13 @@ import torch
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
SamplerOutput,
SamplingMetadata, get_logprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import PyObjectCache
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
......@@ -31,6 +38,29 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def seq_output_builder():
    """Build a placeholder ``SequenceOutput`` for the pythonization cache.

    The placeholder is created with ``0`` for its first two positional
    arguments and a one-entry logprobs dict, keyed by ``0``, whose
    ``Logprob`` has ``logprob=inf`` and no rank or decoded token. Users of
    the cache are expected to overwrite these fields before use.
    """
    placeholder_logprob = Logprob(logprob=float('inf'),
                                  rank=None,
                                  decoded_token=None)
    return SequenceOutput(0, 0, {0: placeholder_logprob})
def completion_seq_group_output_builder():
    """Build an empty ``CompletionSequenceGroupOutput`` for the cache.

    The object starts with an empty list and ``None`` as its two positional
    arguments; both are filled in when the object is reused.
    """
    empty_samples = []
    return CompletionSequenceGroupOutput(empty_samples, None)
# Used by pythonization to reduce python object allocations
class PythonizationCache:
    """Reusable pools of output objects used while pythonizing results.

    Two ``PyObjectCache`` pools are kept: one producing ``SequenceOutput``
    placeholders and one producing ``CompletionSequenceGroupOutput``
    placeholders.
    """

    def __init__(self):
        # Pool of SequenceOutput placeholders.
        self.cached_seq_output = PyObjectCache(seq_output_builder)
        # Pool of CompletionSequenceGroupOutput placeholders.
        self.cached_completion_seq_group_output = PyObjectCache(
            completion_seq_group_output_builder)

    def reset(self):
        """Reset both pools."""
        for pool in (self.cached_seq_output,
                     self.cached_completion_seq_group_output):
            pool.reset()
@dataclass
class ModelOutput:
"""The output of a single model forward pass.
......@@ -51,6 +81,9 @@ class ModelOutput:
sampler_output_ready_event: torch.cuda.Event
sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
pythonization_cache: Optional[PythonizationCache] = None
def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
......@@ -76,7 +109,9 @@ class ModelOutput:
blocking: bool) -> bool:
"""
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output.
ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
"""
assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query():
......@@ -87,7 +122,16 @@ class ModelOutput:
with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer,
self.sampled_token_ids)
self.sampled_token_ids, self.logprobs,
self.pythonization_cache)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self.logprobs = None
return True
......@@ -191,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self._copy_stream = torch.cuda.Stream()
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
self.pythonization_cache = PythonizationCache()
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
......@@ -215,6 +261,79 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
)
return model_input
def _async_process_outputs(self, model_input: StatefulModelInput,
                           output_proc_callback: Callable):
    """Pythonize cached outputs in order and queue them for processing.

    ``output_proc_callback`` is invoked once up front. The cached outputs
    are then visited in order: outputs that are already pythonized are
    skipped, and every other output gets a ``maybe_pythonize`` attempt.
    When the attempt succeeds, the output is appended to the
    ``output_queue`` of the callback's ``ctx`` and the callback is invoked
    again. The walk stops at the first output that fails to pythonize.

    Args:
        model_input: Multi-step model input holding ``cached_outputs``.
        output_proc_callback: Output-processing callback. It must expose a
            ``keywords`` mapping that contains ``"ctx"`` (presumably a
            ``functools.partial`` -- confirm against the caller).
    """
    output_proc_callback()

    for cached in model_input.cached_outputs:
        if cached.pythonized:
            # Nothing left to pythonize for this output here.
            continue

        cached.maybe_pythonize(model_input, self._copy_stream,
                               self.pinned_sampled_token_ids)
        if not cached.pythonized:
            # Stop on the first one that fails to pythonize.
            break

        ctx = output_proc_callback.keywords["ctx"]
        is_async = False
        is_last_step = False
        queue_entry = ([cached.sampler_output], ctx.seq_group_metadata_list,
                       ctx.scheduler_outputs, is_async, is_last_step)
        ctx.output_queue.append(queue_entry)

        output_proc_callback()
def _final_process_outputs(self, model_input: StatefulModelInput,
                           output_proc_callback: Optional[Callable]):
    """Pythonize every cached output and collect the ones to return.

    Without a callback, each cached output is pythonized and its
    ``sampler_output`` is added to the returned list.

    With a callback, the callback is invoked before each output is
    examined (to overlap with GPU work). An output that is not yet
    pythonized is pythonized; if it is not the last cached output it is
    appended to the ``output_queue`` of the callback's ``ctx`` so that
    callback/pythonize pairs are chained, and if it is the last one its
    ``sampler_output`` is added to the returned list.

    Args:
        model_input: Multi-step model input; its ``frozen_model_input``
            must not be ``None``.
        output_proc_callback: Optional output-processing callback exposing
            a ``keywords`` mapping with ``"ctx"``.

    Returns:
        The list of collected sampler outputs.
    """
    assert model_input.frozen_model_input is not None

    cached_outputs = model_input.cached_outputs
    num_cached = len(cached_outputs)
    collected = []

    for position in range(num_cached):
        cached = cached_outputs[position]

        if output_proc_callback is None:
            # Non-async case: simply pythonize and add the output.
            cached.pythonize(model_input, self._copy_stream,
                             self.pinned_sampled_token_ids)
            collected.append(cached.sampler_output)
            continue

        # Async case: invoke callback before pythonize (to overlap with GPU).
        output_proc_callback()

        if cached.pythonized:
            continue

        cached.pythonize(model_input, self._copy_stream,
                         self.pinned_sampled_token_ids)

        if position == num_cached - 1:
            # Last output: hand it back to the caller directly.
            collected.append(cached.sampler_output)
            continue

        # For non last step, add to callback queue to chain
        # callbacks=>pythonize pairs (for GPU overlap).
        ctx = output_proc_callback.keywords["ctx"]  # type: ignore
        is_async = False
        is_last_step = False
        queue_entry = ([cached.sampler_output], ctx.seq_group_metadata_list,
                       ctx.scheduler_outputs, is_async, is_last_step)
        ctx.output_queue.append(queue_entry)

    return collected
@torch.inference_mode()
def execute_model(
self,
......@@ -271,6 +390,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output)
output_proc_callback = None
if frozen_model_input.async_callback is not None:
output_proc_callback = frozen_model_input.async_callback
assert output_proc_callback is not None
async_callback = functools.partial(
self._async_process_outputs,
model_input=model_input,
output_proc_callback=output_proc_callback)
frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
assert frozen_model_input is not None
# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches,
......@@ -294,16 +427,23 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
0].sampled_token_ids.cpu()
model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False))
# make sure we dont try to serialize any GPU tensors
output[0].sampled_token_ids, False,
output[0].logprobs, self.pythonization_cache))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output[0].sampled_token_ids = None
output[0].sampled_token_probs = None
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if frozen_model_input.async_callback is None:
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input,
self._copy_stream,
self.pinned_sampled_token_ids)
model_input.current_step += 1
......@@ -316,11 +456,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = []
for output in model_input.cached_outputs:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
outputs = self._final_process_outputs(model_input,
output_proc_callback)
self.pythonization_cache.reset()
return outputs
# should be [SamplerOutput]
......@@ -409,12 +547,76 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return self._base_model_runner.vocab_size
def _pythonize_sampler_output(model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor) -> None:
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
Optional[List[SampleLogprobs]]]
def deferred_pythonize_logprobs(
    output: SamplerOutput,
    sampling_metadata: SamplingMetadata,
    logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
    """Perform deferred logprob Pythonization.

    1. Pythonize the deferred sampler result stored on ``output``.
    2. Pythonize ``logprobs_tensor`` into logprobs lists, using the
       Pythonized sampler result from step 1.

    These deferred computations are not required for single-step scheduling
    or the `profile_run()` phase of multi-step scheduling.

    Args:
        output: sampler output (under deferred Pythonization); its
            ``deferred_sample_results_args`` is set to ``None`` here.
        sampling_metadata: sampling metadata for the batch; one result per
            entry in ``seq_groups`` is expected.
        logprobs_tensor: tensor of logprobs handed to ``get_logprobs``.

    Returns:
        prompt_logprobs, sample_logprobs
    """
    # Deferred pythonization of the sample result.
    sampler_result = get_pythonized_sample_results(
        output.deferred_sample_results_args)

    # Erase the deferred sample_result computation args to ensure they are
    # never pythonized or transferred to CPU.
    output.deferred_sample_results_args = None

    # Deferred pythonization of logprobs.
    prompt_logprobs, sample_logprobs = get_logprobs(logprobs_tensor,
                                                    sampling_metadata,
                                                    sampler_result)

    # Both lists must line up with the sequence groups.
    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)

    return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
cache: Optional[PythonizationCache],
) -> None:
""" This function is only called when the output tensors are ready.
See ModelOutput
See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
GPU-side token buffer.)
sampled_token_ids: GPU-side token buffer
logprobs_tensor: GPU-side tensor containing
logprobs computed during sampling
"""
assert model_input.frozen_model_input is not None
......@@ -434,20 +636,107 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
sampling_metadata = frozen_model_input.sampling_metadata
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
samples_list):
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
seq_outputs: List[SequenceOutput] = []
skip_sampler_cpu_output = (
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
# We are guaranteed output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups = sampling_metadata.seq_groups
logprobs_are_requested = any([
sg.sampling_params.logprobs is not None
or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
])
do_pythonize_logprobs = (skip_sampler_cpu_output
and logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
) = (deferred_pythonize_logprobs(output, sampling_metadata,
logprobs_tensor)
if do_pythonize_logprobs else (None, None))
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding")
for parent_id, next_token_id in zip(parent_ids, next_token_ids):
# TODO(will): support logprobs
# Hard coded logprob
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
{next_token_id: Logprob(logprob=-1)}))
output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
if do_pythonize_logprobs:
assert prompt_logprobs is not None
assert sample_logprobs is not None
(
group_prompt_logprobs,
group_sample_logprobs,
) = ( # Utilize deferred pythonization results
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
) = (
# profile_run: use already-computed logprobs
output.outputs[sgdx].prompt_logprobs,
[sample.logprobs for sample in output.outputs[sgdx].samples])
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear()
seq_outputs: List[
SequenceOutput] = completion_seq_group_output.samples
else:
seq_outputs = []
for tdx, (parent_id,
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
if cache is not None:
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
)
seq_output.parent_seq_id = seq_ids[parent_id]
seq_output.output_token = next_token_id
if logprobs_are_requested:
seq_output.logprobs = group_sample_logprobs[tdx]
else:
logprobs = next(iter(seq_output.logprobs.values()))
seq_output.logprobs.clear()
logprobs.logprob = float('inf')
logprobs.rank = None
logprobs.decoded_token = None
seq_output.logprobs[next_token_id] = logprobs
seq_outputs.append(seq_output)
else:
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
if cache is not None:
completion_seq_group_output.prompt_logprobs = \
group_prompt_logprobs if logprobs_are_requested else None
output.outputs.append(completion_seq_group_output)
else:
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs, (group_prompt_logprobs
if logprobs_are_requested else None)))
assert len(output.outputs) > 0
import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput)
......@@ -61,6 +63,11 @@ class MultiStepWorker(Worker):
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=execute_model_req.async_callback)
else:
# on subsequent steps we reuse the worker input and model input
multi_step_state = self.multi_step_states[virtual_engine]
......
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
......@@ -8,11 +9,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.model: nn.Module # initialize after load_model.
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
if find_spec("transformers_neuronx") is not None:
self.model = get_neuron_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
else:
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
def _prepare_prompt(
self,
......
......@@ -6,6 +6,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
......@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
......@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.is_driver_worker = True
def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
......@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
This is required for speculative decoding; it is not yet implemented.
"""
raise NotImplementedError
def init_distributed_environment(self):
    """Initialize vLLM's distributed state for the Neuron worker.

    Neuron uses transformers-neuronx for tensor parallelism, but vLLM still
    needs the environment inited when TP/PP > 1. The environment is set up
    with a world size of 1 on the "gloo" backend, using this worker's
    ``rank``, ``local_rank`` and ``distributed_init_method``.
    """
    dist_kwargs = dict(
        world_size=1,
        rank=self.rank,
        local_rank=self.local_rank,
        distributed_init_method=self.distributed_init_method,
        backend="gloo",
    )
    # The bare name resolves to the module-level function, not this method.
    init_distributed_environment(**dist_kwargs)

    # Presumably the tensor- and pipeline-parallel sizes -- confirm against
    # vllm.distributed.
    ensure_model_parallel_initialized(1, 1)
......@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__)
......
......@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch
import numpy as np
......@@ -10,14 +11,15 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
......@@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
best_of: List[int]
seq_groups: List[List[int]]
virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -144,11 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
self.model = ModelWrapper(model)
def _dummy_run(
self,
......@@ -235,8 +234,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
self.model(token_ids,
position_ids,
attn_metadata,
input_lens,
t,
p,
num_samples,
kv_caches,
is_prompt=is_prompt)
def warmup_model(
self,
......@@ -530,7 +536,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = arg.context_lens.to(self.device)
new_args.append(arg)
return self.model(*new_args)
return self.model(*new_args, is_prompt=is_prompt)
num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0
......@@ -558,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input.attn_metadata, model_input.input_lens[i:i + 1],
model_input.t[i:i + 1], model_input.p[i:i + 1],
model_input.num_samples, kv_caches)
if i == 0 and model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx
......@@ -568,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples,
kv_caches)
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist()
......@@ -591,7 +601,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
batch_idx += 1
else:
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx][0]
next_token_id = next_token_ids[batch_idx]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
......@@ -601,11 +611,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return [SamplerOutput(sampler_outputs)]
class ModelWrapper(nn.Module):
class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
compiled_callable = torch.compile(self.forward,
backend="openxla",
fullgraph=True,
dynamic=False)
super().__init__(compiled_callable)
def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return self.compiled_callable(*args, **kwargs)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if is_prompt:
with self.dispatch_to_code(1):
return self.forward(*args, **kwargs)
else:
with self.dispatch_to_code(2):
return self.forward(*args, **kwargs)
def forward(
self,
......@@ -691,6 +722,9 @@ class ModelWrapper(nn.Module):
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
if num_samples == 1:
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids
......
......@@ -102,8 +102,9 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{self.rank}")
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
def load_model(self):
......
......@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.model_config.multimodal_config is not None:
if enc_dec_mr.model_config.is_multimodal_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
......
......@@ -17,12 +17,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupMetadataDelta)
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
......
......@@ -11,9 +11,9 @@ from vllm.config import ObservabilityConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
......@@ -263,6 +263,11 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0)
if execute_model_req.async_callback:
model_input = dataclasses.replace( # type: ignore
model_input,
async_callback=execute_model_req.async_callback)
return model_input, worker_input, kwargs
def prepare_input(
......@@ -289,7 +294,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
......
......@@ -12,14 +12,15 @@ from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
......@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
......@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
self.execute_model(model_input, kv_caches)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.xpu.synchronize()
return
......@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()
hidden_states = model_executable(
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
......@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device))
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time()
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
......
......@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu
......@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if parallel_config.pipeline_parallel_size > 1:
# torch-ccl on XPU needs a collective API warm-up
# before calling the send/recv API
get_pp_group().all_reduce(torch.zeros(1).xpu())
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment