Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
...@@ -6,8 +6,8 @@ import time ...@@ -6,8 +6,8 @@ import time
import warnings import warnings
import weakref import weakref
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
TypeVar, Union) Tuple, Type, TypeVar, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState ...@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
...@@ -29,6 +30,7 @@ from vllm.lora.layers import LoRAMapping ...@@ -29,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora, from vllm.model_executor.models.interfaces import (supports_lora,
...@@ -41,10 +43,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -41,10 +43,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import ( from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager) LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available) flatten_2d_lists, is_hip, is_pin_memory_available,
supports_dynamo)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
...@@ -59,10 +61,14 @@ logger = init_logger(__name__) ...@@ -59,10 +61,14 @@ logger = init_logger(__name__)
LORA_WARMUP_RANK = 8 LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8 _BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
] ]
_NUM_WARMUP_ITERS = 2 _NUM_WARMUP_ITERS = 2
...@@ -90,6 +96,9 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -90,6 +96,9 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0 virtual_engine: int = 0
async_callback: Optional[Callable] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
...@@ -499,23 +508,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -499,23 +508,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
and self.sliding_window is None and self.sliding_window is None
and inter_data.is_prompt) and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError( if not prefix_cache_hit:
"chunked prefill cannot be used with prefix caching now.") return
# If prefix cache is hit, advance context length to bypass assert computed_block_nums is not None
# hit blocks. Accordingly, input tokens, position and query length # The cache hit prompt tokens in this sequence. Note that
# have to be updated. # this may be larger than the sequence length if chunked
if prefix_cache_hit: # prefill is enabled.
assert computed_block_nums is not None prefix_cache_len = len(computed_block_nums) * self.block_size
context_len = len(computed_block_nums) * self.block_size # The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:] seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:] seq_idx][uncomputed_start:]
context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[ inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int, seq_idx: int,
...@@ -632,7 +666,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -632,7 +666,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def _use_captured_graph(self, batch_size: int, def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool: max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager return (self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and batch_size <= self.runner.max_batchsize_to_capture
and max_decode_seq_len <= self.runner.max_seq_len_to_capture) and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
def build(self) -> ModelInputForGPU: def build(self) -> ModelInputForGPU:
...@@ -818,6 +852,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -818,6 +852,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.max_batchsize_to_capture = _get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size) {} for _ in range(self.parallel_config.pipeline_parallel_size)
...@@ -835,7 +871,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -835,7 +871,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# The shape of the cached block table will be # The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size). # (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = np.zeros( self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32) dtype=np.int32)
num_attn_heads = self.model_config.get_num_attention_heads( num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)
...@@ -945,7 +981,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -945,7 +981,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. " "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
self.model = torch.compile(self.model, self.model = torch.compile(self.model,
fullgraph=True, fullgraph=True,
backend="eager") backend="eager")
...@@ -1220,7 +1256,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1220,7 +1256,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
start_time = time.perf_counter() start_time = time.perf_counter()
# Prepare dummy inputs. These will be reused for all batch sizes. # Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
...@@ -1248,8 +1284,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1248,8 +1284,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
None None
] * self.parallel_config.pipeline_parallel_size ] * self.parallel_config.pipeline_parallel_size
graph_batch_size = _get_graph_batch_size( graph_batch_size = self.max_batchsize_to_capture
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [ batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
] ]
...@@ -1357,7 +1392,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1357,7 +1392,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1481,6 +1516,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1481,6 +1516,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not self.is_driver_worker: if not self.is_driver_worker:
return [] return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token. # Sample the next token.
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.model.sample(
logits=logits, logits=logits,
...@@ -1672,3 +1710,22 @@ def _get_graph_batch_size(batch_size: int) -> int: ...@@ -1672,3 +1710,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
else: else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _get_max_graph_batch_size(max_num_seqs: int) -> int:
    """Return the largest batch size that should be captured as a CUDA graph.

    ``max_num_seqs`` (the maximum number of sequences in a batch) is first
    padded via ``_get_graph_batch_size``, which also takes care of the small
    special sizes such as 1, 2 and 4.

    Args:
        max_num_seqs: Maximum number of sequences in a batch.

    Returns:
        The padded size when it is one of ``_BATCH_SIZES_TO_CAPTURE``;
        otherwise the last (largest) entry of ``_BATCH_SIZES_TO_CAPTURE``.
    """
    padded = _get_graph_batch_size(max_num_seqs)
    if padded not in _BATCH_SIZES_TO_CAPTURE:
        # A padded size missing from the list is expected to exceed every
        # capturable size, so clamp it to the largest one.
        assert padded > _BATCH_SIZES_TO_CAPTURE[-1]
        return _BATCH_SIZES_TO_CAPTURE[-1]
    return padded
...@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, ...@@ -5,9 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
......
import dataclasses
import functools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
try: try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata
...@@ -13,9 +16,13 @@ import torch ...@@ -13,9 +16,13 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
SamplerOutput,
SamplingMetadata, get_logprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata, Logprob, SequenceGroupMetadata, SequenceOutput)
SequenceOutput) from vllm.utils import PyObjectCache
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -31,6 +38,29 @@ if TYPE_CHECKING: ...@@ -31,6 +38,29 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def seq_output_builder():
    """Build a placeholder ``SequenceOutput`` for the pythonization cache.

    The object is created with parent sequence id 0, output token 0 and a
    single logprob entry for token 0 whose logprob is ``inf`` with no rank
    and no decoded token.
    """
    placeholder_logprob = Logprob(logprob=float('inf'),
                                  rank=None,
                                  decoded_token=None)
    return SequenceOutput(0, 0, {0: placeholder_logprob})
def completion_seq_group_output_builder():
    """Build an empty ``CompletionSequenceGroupOutput`` for the object cache.

    The returned object has a fresh, empty samples list and ``None`` as its
    second constructor argument.
    """
    empty_samples = []
    return CompletionSequenceGroupOutput(empty_samples, None)
# Used by pythonization to reduce python object allocations
class PythonizationCache:
    """Object caches for the output objects created during pythonization.

    Holds one ``PyObjectCache`` per output type: one built by
    ``seq_output_builder`` and one built by
    ``completion_seq_group_output_builder``.
    """

    def __init__(self):
        self.cached_seq_output = PyObjectCache(seq_output_builder)
        self.cached_completion_seq_group_output = PyObjectCache(
            completion_seq_group_output_builder)

    def reset(self):
        """Reset both underlying object caches."""
        for object_cache in (self.cached_seq_output,
                             self.cached_completion_seq_group_output):
            object_cache.reset()
@dataclass @dataclass
class ModelOutput: class ModelOutput:
"""The output of a single model forward pass. """The output of a single model forward pass.
...@@ -51,6 +81,9 @@ class ModelOutput: ...@@ -51,6 +81,9 @@ class ModelOutput:
sampler_output_ready_event: torch.cuda.Event sampler_output_ready_event: torch.cuda.Event
sampled_token_ids: Optional[torch.Tensor] = None sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False pythonized: bool = False
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
pythonization_cache: Optional[PythonizationCache] = None
def pythonize(self, input_metadata: "StatefulModelInput", def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream, copy_stream: torch.cuda.Stream,
...@@ -76,7 +109,9 @@ class ModelOutput: ...@@ -76,7 +109,9 @@ class ModelOutput:
blocking: bool) -> bool: blocking: bool) -> bool:
""" """
If blocking is set, will block until the forward pass for the output is If blocking is set, will block until the forward pass for the output is
ready and pythonize the output. ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
""" """
assert self.sampled_token_ids is not None assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query(): if not blocking and not self.sampler_output_ready_event.query():
...@@ -87,7 +122,16 @@ class ModelOutput: ...@@ -87,7 +122,16 @@ class ModelOutput:
with torch.cuda.stream(copy_stream): with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output, _pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer, pinned_sampled_token_buffer,
self.sampled_token_ids) self.sampled_token_ids, self.logprobs,
self.pythonization_cache)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self.logprobs = None
return True return True
...@@ -191,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -191,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self._copy_stream = torch.cuda.Stream() self._copy_stream = torch.cuda.Stream()
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
self.pythonization_cache = PythonizationCache()
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict( model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
...@@ -215,6 +261,79 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -215,6 +261,79 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
) )
return model_input return model_input
def _async_process_outputs(self, model_input: StatefulModelInput,
                           output_proc_callback: Callable):
    """Pythonize cached outputs in order and feed them to the callback.

    The callback is invoked once up front. Then each entry of
    ``model_input.cached_outputs`` is visited in order: entries that are
    already pythonized are skipped, the others get a ``maybe_pythonize``
    attempt. Processing stops at the first entry that is still not
    pythonized after that attempt.

    Args:
        model_input: Stateful model input whose ``cached_outputs`` are
            processed.
        output_proc_callback: Output-processing callback; it must expose a
            ``keywords["ctx"]`` entry (as a ``functools.partial`` does).
    """
    output_proc_callback()

    for model_output in model_input.cached_outputs:
        if model_output.pythonized:
            # Already pythonized before this call; nothing to do.
            continue

        model_output.maybe_pythonize(model_input, self._copy_stream,
                                     self.pinned_sampled_token_ids)
        if not model_output.pythonized:
            # Stop on the first output that fails to pythonize.
            break

        ctx = output_proc_callback.keywords["ctx"]
        is_async = False
        is_last_step = False
        queue_entry = ([model_output.sampler_output],
                       ctx.seq_group_metadata_list, ctx.scheduler_outputs,
                       is_async, is_last_step)
        ctx.output_queue.append(queue_entry)
        output_proc_callback()
def _final_process_outputs(self, model_input: StatefulModelInput,
                           output_proc_callback: Optional[Callable]):
    """Pythonize every cached output and collect the sampler outputs.

    Without a callback, each cached output is pythonized and its sampler
    output is added to the result.

    With a callback, each cached output is handled as: invoke the callback,
    pythonize if not yet pythonized, then either push the sampler output
    onto the callback context's ``output_queue`` (every output except the
    last) or add it to the result (the last output only).

    Args:
        model_input: Stateful model input holding ``cached_outputs``; its
            ``frozen_model_input`` must be set.
        output_proc_callback: Optional output-processing callback exposing
            ``keywords["ctx"]``.

    Returns:
        The list of sampler outputs that were not pushed to the callback
        queue.
    """
    assert model_input.frozen_model_input is not None

    outputs = []

    if output_proc_callback is None:
        # Non-async case: simply pythonize and collect every output.
        for output in model_input.cached_outputs:
            output.pythonize(model_input, self._copy_stream,
                             self.pinned_sampled_token_ids)
            outputs.append(output.sampler_output)
        return outputs

    # Async case: chain callback => pythonize pairs.
    for output_id, output in enumerate(model_input.cached_outputs):
        final_output = output_id == len(model_input.cached_outputs) - 1

        # Invoke callback before pythonize (to overlap with GPU)
        output_proc_callback()

        if not output.pythonized:
            output.pythonize(model_input, self._copy_stream,
                             self.pinned_sampled_token_ids)

        if final_output:
            # The last output is returned instead of being queued.
            outputs.append(output.sampler_output)
            continue

        # Non-last outputs go to the callback queue so the next callback
        # invocation processes them.
        ctx = output_proc_callback.keywords["ctx"]  # type: ignore
        is_async = False
        is_last_step = False
        ctx.output_queue.append(
            ([output.sampler_output], ctx.seq_group_metadata_list,
             ctx.scheduler_outputs, is_async, is_last_step))

    return outputs
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -271,6 +390,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -271,6 +390,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input = self._advance_step( model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output) model_input, model_input.cached_outputs[-1].sampler_output)
output_proc_callback = None
if frozen_model_input.async_callback is not None:
output_proc_callback = frozen_model_input.async_callback
assert output_proc_callback is not None
async_callback = functools.partial(
self._async_process_outputs,
model_input=model_input,
output_proc_callback=output_proc_callback)
frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
assert frozen_model_input is not None
# Execute the model # Execute the model
output = self._base_model_runner.execute_model(frozen_model_input, output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches, kv_caches,
...@@ -294,16 +427,23 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -294,16 +427,23 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
0].sampled_token_ids.cpu() 0].sampled_token_ids.cpu()
model_input.cached_outputs.append( model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event, ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False)) output[0].sampled_token_ids, False,
# make sure we dont try to serialize any GPU tensors output[0].logprobs, self.pythonization_cache))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output[0].sampled_token_ids = None output[0].sampled_token_ids = None
output[0].sampled_token_probs = None output[0].sampled_token_probs = None
output[0].logprobs = None output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is # Pythonize the output if CPU is ahead and the previous step is
# ready. # ready.
for model_output in model_input.cached_outputs: if frozen_model_input.async_callback is None:
model_output.maybe_pythonize(model_input, self._copy_stream, for model_output in model_input.cached_outputs:
self.pinned_sampled_token_ids) model_output.maybe_pythonize(model_input,
self._copy_stream,
self.pinned_sampled_token_ids)
model_input.current_step += 1 model_input.current_step += 1
...@@ -316,11 +456,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -316,11 +456,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step # Pythonize the output and block if needed since it is the last step
if model_input.is_last_step: if model_input.is_last_step:
outputs = [] outputs = self._final_process_outputs(model_input,
for output in model_input.cached_outputs: output_proc_callback)
output.pythonize(model_input, self._copy_stream, self.pythonization_cache.reset()
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
return outputs return outputs
# should be [SamplerOutput] # should be [SamplerOutput]
...@@ -409,12 +547,76 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -409,12 +547,76 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
return self._base_model_runner.vocab_size return self._base_model_runner.vocab_size
def _pythonize_sampler_output(model_input: StatefulModelInput, DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
output: SamplerOutput, Optional[List[SampleLogprobs]]]
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor) -> None:
def deferred_pythonize_logprobs(
    output: SamplerOutput,
    sampling_metadata: SamplingMetadata,
    logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
    """Perform deferred logprob Pythonization.

    1. Pythonize the deferred sampler result stored on ``output``.
    2. Pythonize the logprobs tensor into per-sequence-group logprobs lists,
       using the sampler result from step 1.

    These deferred computations are not required for single-step scheduling
    or the `profile_run()` phase of multi-step scheduling.

    Args:
        output: sampler output (under deferred Pythonization); its
            ``deferred_sample_results_args`` is consumed and set to ``None``.
        sampling_metadata: sampling metadata whose ``seq_groups`` determine
            the expected number of result entries.
        logprobs_tensor: tensor of logprobs passed through to
            ``get_logprobs``.

    Returns:
        prompt_logprobs, sample_logprobs
    """
    # Deferred pythonization of the sample result.
    deferred_args = output.deferred_sample_results_args
    sampler_result = get_pythonized_sample_results(deferred_args)

    # Drop the deferred sample-result args so they are never pythonized
    # again or transferred to CPU.
    output.deferred_sample_results_args = None

    # Deferred pythonization of logprobs.
    prompt_logprobs, sample_logprobs = get_logprobs(logprobs_tensor,
                                                    sampling_metadata,
                                                    sampler_result)

    # One entry per sequence group is expected in each result list.
    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)

    return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
cache: Optional[PythonizationCache],
) -> None:
""" This function is only called when the output tensors are ready. """ This function is only called when the output tensors are ready.
See ModelOutput See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
GPU-side token buffer.)
sampled_token_ids: GPU-side token buffer
logprobs_tensor: GPU-side tensor containing
logprobs computed during sampling
""" """
assert model_input.frozen_model_input is not None assert model_input.frozen_model_input is not None
...@@ -434,20 +636,107 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, ...@@ -434,20 +636,107 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
sampling_metadata = frozen_model_input.sampling_metadata sampling_metadata = frozen_model_input.sampling_metadata
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, skip_sampler_cpu_output = (
samples_list): frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
seq_ids = seq_group.seq_ids
next_token_ids = sample_result # We are guaranteed output tensors are ready, so it is safe to
parent_ids = [0] # pythonize the sampler output & obtain CPU-side logprobs.
seq_outputs: List[SequenceOutput] = [] #
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups = sampling_metadata.seq_groups
logprobs_are_requested = any([
sg.sampling_params.logprobs is not None
or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
])
do_pythonize_logprobs = (skip_sampler_cpu_output
and logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
) = (deferred_pythonize_logprobs(output, sampling_metadata,
logprobs_tensor)
if do_pythonize_logprobs else (None, None))
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
if seq_group.sampling_params.logits_processors: if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, ( assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding") "Logits Processors are not supported in multi-step decoding")
for parent_id, next_token_id in zip(parent_ids, next_token_ids):
# TODO(will): support logprobs if do_pythonize_logprobs:
# Hard coded logprob assert prompt_logprobs is not None
seq_outputs.append( assert sample_logprobs is not None
SequenceOutput(seq_ids[parent_id], next_token_id,
{next_token_id: Logprob(logprob=-1)})) (
output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) group_prompt_logprobs,
group_sample_logprobs,
) = ( # Utilize deferred pythonization results
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
) = (
# profile_run: use already-computed logprobs
output.outputs[sgdx].prompt_logprobs,
[sample.logprobs for sample in output.outputs[sgdx].samples])
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear()
seq_outputs: List[
SequenceOutput] = completion_seq_group_output.samples
else:
seq_outputs = []
for tdx, (parent_id,
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
if cache is not None:
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
)
seq_output.parent_seq_id = seq_ids[parent_id]
seq_output.output_token = next_token_id
if logprobs_are_requested:
seq_output.logprobs = group_sample_logprobs[tdx]
else:
logprobs = next(iter(seq_output.logprobs.values()))
seq_output.logprobs.clear()
logprobs.logprob = float('inf')
logprobs.rank = None
logprobs.decoded_token = None
seq_output.logprobs[next_token_id] = logprobs
seq_outputs.append(seq_output)
else:
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
if cache is not None:
completion_seq_group_output.prompt_logprobs = \
group_prompt_logprobs if logprobs_are_requested else None
output.outputs.append(completion_seq_group_output)
else:
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs, (group_prompt_logprobs
if logprobs_are_requested else None)))
assert len(output.outputs) > 0 assert len(output.outputs) > 0
import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import BroadcastableModelInput from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput) StatefulModelInput)
...@@ -61,6 +63,11 @@ class MultiStepWorker(Worker): ...@@ -61,6 +63,11 @@ class MultiStepWorker(Worker):
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine, execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids)) execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=execute_model_req.async_callback)
else: else:
# on subsequent steps we reuse the worker input and model input # on subsequent steps we reuse the worker input and model input
multi_step_state = self.multi_step_states[virtual_engine] multi_step_state = self.multi_step_states[virtual_engine]
......
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -8,11 +9,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -8,11 +9,11 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
...@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.model: nn.Module # initialize after load_model. self.model: nn.Module # initialize after load_model.
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_neuron_model(self.model_config, if find_spec("transformers_neuronx") is not None:
parallel_config=self.parallel_config, self.model = get_neuron_model(
scheduler_config=self.scheduler_config) self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
else:
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
def _prepare_prompt( def _prepare_prompt(
self, self,
......
...@@ -6,6 +6,8 @@ import torch.distributed ...@@ -6,6 +6,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.neuron_model_runner import NeuronModelRunner
...@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
...@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.is_driver_worker = True self.is_driver_worker = True
def init_device(self) -> None: def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
...@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
This is required for speculative decoding; it is not yet implemented. This is required for speculative decoding; it is not yet implemented.
""" """
raise NotImplementedError raise NotImplementedError
def init_distributed_environment(self):
"""Neuron uses transformers-neuronx for tensor parallelism.
vLLM still needs the environment inited when TP/PP > 1
"""
init_distributed_environment(
world_size=1,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
1,
1,
)
...@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -11,10 +11,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict, ...@@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
...@@ -10,14 +11,15 @@ import torch_xla.core.xla_model as xm ...@@ -10,14 +11,15 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata, Logprob, SequenceGroupMetadata, SequenceOutput)
SequenceOutput)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
...@@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase): ...@@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
best_of: List[int] best_of: List[int]
seq_groups: List[List[int]] seq_groups: List[List[int]]
virtual_engine: int = 0 virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -144,11 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -144,11 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
) )
model = model.eval() model = model.eval()
xm.wait_device_ops() xm.wait_device_ops()
model = ModelWrapper(model) self.model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def _dummy_run( def _dummy_run(
self, self,
...@@ -235,8 +234,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -235,8 +234,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0) torch._dynamo.mark_dynamic(p, 0)
# Dummy run. # Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, self.model(token_ids,
num_samples, kv_caches) position_ids,
attn_metadata,
input_lens,
t,
p,
num_samples,
kv_caches,
is_prompt=is_prompt)
def warmup_model( def warmup_model(
self, self,
...@@ -530,7 +536,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -530,7 +536,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
if getattr(arg, "context_lens", None) is not None: if getattr(arg, "context_lens", None) is not None:
arg.context_lens = arg.context_lens.to(self.device) arg.context_lens = arg.context_lens.to(self.device)
new_args.append(arg) new_args.append(arg)
return self.model(*new_args) return self.model(*new_args, is_prompt=is_prompt)
num_prefills = model_input.attn_metadata.num_prefills num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0 is_prompt = num_prefills > 0
...@@ -558,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -558,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input.attn_metadata, model_input.input_lens[i:i + 1], model_input.attn_metadata, model_input.input_lens[i:i + 1],
model_input.t[i:i + 1], model_input.p[i:i + 1], model_input.t[i:i + 1], model_input.p[i:i + 1],
model_input.num_samples, kv_caches) model_input.num_samples, kv_caches)
if i == 0 and model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU. # Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist() next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx start_idx = end_idx
...@@ -568,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -568,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input.attn_metadata, model_input.input_lens, model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples, model_input.t, model_input.p, model_input.num_samples,
kv_caches) kv_caches)
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU. # Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist() next_token_ids = output_token_ids.cpu().tolist()
...@@ -591,7 +601,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -591,7 +601,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
batch_idx += 1 batch_idx += 1
else: else:
for seq_id in seq_ids: for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx][0] next_token_id = next_token_ids[batch_idx]
seq_outputs.append( seq_outputs.append(
SequenceOutput(seq_id, next_token_id, SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob})) {next_token_id: zero_logprob}))
...@@ -601,11 +611,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -601,11 +611,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return [SamplerOutput(sampler_outputs)] return [SamplerOutput(sampler_outputs)]
class ModelWrapper(nn.Module): class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
super().__init__()
self.model = model self.model = model
compiled_callable = torch.compile(self.forward,
backend="openxla",
fullgraph=True,
dynamic=False)
super().__init__(compiled_callable)
def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return self.compiled_callable(*args, **kwargs)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if is_prompt:
with self.dispatch_to_code(1):
return self.forward(*args, **kwargs)
else:
with self.dispatch_to_code(2):
return self.forward(*args, **kwargs)
def forward( def forward(
self, self,
...@@ -691,6 +722,9 @@ class ModelWrapper(nn.Module): ...@@ -691,6 +722,9 @@ class ModelWrapper(nn.Module):
sampled_token_ids = torch.multinomial(probs, sampled_token_ids = torch.multinomial(probs,
num_samples, num_samples,
replacement=True) replacement=True)
if num_samples == 1:
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
next_token_ids = torch.where(t != 0, sampled_token_ids, next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids) argmax_token_ids)
return next_token_ids return next_token_ids
......
...@@ -102,8 +102,9 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -102,8 +102,9 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# NOTE(woosuk): Set per-rank cache path since different ranks # NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs. # can have slightly different XLA graphs.
world_size = self.parallel_config.world_size world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{self.rank}") f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False) xr.initialize_cache(per_rank_path, readonly=False)
def load_model(self): def load_model(self):
......
...@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario( ...@@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario(
raise NotImplementedError( raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.model_config.multimodal_config is not None: if enc_dec_mr.model_config.is_multimodal_model:
raise NotImplementedError( raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
......
...@@ -17,12 +17,12 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -17,12 +17,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput, SequenceGroupMetadata, SequenceGroupMetadata, SequenceGroupMetadataDelta)
SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
......
...@@ -11,9 +11,9 @@ from vllm.config import ObservabilityConfig ...@@ -11,9 +11,9 @@ from vllm.config import ObservabilityConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import ExecuteModelRequest, IntermediateTensors
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables) update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput, from vllm.worker.model_runner_base import (BroadcastableModelInput,
...@@ -263,6 +263,11 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -263,6 +263,11 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data.update(kwargs) broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0) broadcast_tensor_dict(broadcast_data, src=0)
if execute_model_req.async_callback:
model_input = dataclasses.replace( # type: ignore
model_input,
async_callback=execute_model_req.async_callback)
return model_input, worker_input, kwargs return model_input, worker_input, kwargs
def prepare_input( def prepare_input(
...@@ -289,7 +294,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -289,7 +294,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no """Executes at least one model step on the given sequences, unless no
sequences are provided.""" sequences are provided."""
......
...@@ -12,14 +12,15 @@ from vllm.attention import get_attn_backend ...@@ -12,14 +12,15 @@ from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
"Setting it to the minimum value of 1.", expr) "Setting it to the minimum value of 1.", expr)
max_num_seqs = 1 max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \ seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(self.model_config, .dummy_data_for_profiling(self.model_config,
...@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
finished_requests_ids = [seq.request_id for seq in seqs] finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input( model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids) seqs, finished_requests_ids=finished_requests_ids)
self.execute_model(model_input, kv_caches) intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.xpu.synchronize() torch.xpu.synchronize()
return return
...@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time() model_forward_start_time = time.time()
hidden_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
...@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device)) device=self.device))
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time() model_forward_end_time = time.time()
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
......
...@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu from vllm.utils import is_xpu
...@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
if parallel_config.pipeline_parallel_size > 1:
# torch-ccl xpu need a collective API warm up
# before calling send/recv API
get_pp_group().all_reduce(torch.zeros(1).xpu())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment