Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
......@@ -2,13 +2,14 @@ import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
......@@ -53,7 +54,7 @@ class WorkerBase(ABC):
"""
raise NotImplementedError
@torch.inference_mode()
@current_platform.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
......@@ -274,13 +275,40 @@ class LocalOrDistributedWorkerBase(WorkerBase):
num_steps)
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
# output is List[SamplerOutput]
return output
def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, (
"_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None)
class WorkerWrapperBase:
"""
......@@ -288,15 +316,24 @@ class WorkerWrapperBase:
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
If worker_class_fn is specified, it will be executed to get the worker
class.
Otherwise, the worker class will be obtained by dynamically importing it
using worker_module_name and worker_class_name.
"""
def __init__(self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False) -> None:
def __init__(
self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False,
worker_class_fn: Optional[Callable[[],
Type[WorkerBase]]] = None) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker = None
self.worker_class_fn = worker_class_fn
self.worker: Optional[WorkerBase] = None
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
......@@ -321,9 +358,14 @@ class WorkerWrapperBase:
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
if self.worker_class_fn:
worker_class = self.worker_class_fn()
else:
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None
def execute_method(self, method, *args, **kwargs):
try:
......
......@@ -335,11 +335,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
dtype=torch.int,
device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment