Unverified commit 59a0192f, authored by Cyrus Leung, committed by GitHub
Browse files

[Core] Interface for accessing model from `VllmRunner` (#10353)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 83609791
from typing import List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group,
......@@ -15,6 +16,10 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
logger = init_logger(__name__)
class _DummyModel(nn.Module):
    """Empty ``nn.Module`` subclass that defines no layers or parameters."""
class SmallerTpProposerWorker(ProposerWorkerBase):
"""Class which allows a speculative draft model to run with smaller tensor
parallel degree than target model.
......@@ -139,6 +144,13 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
return self._worker.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
def get_model(self) -> nn.Module:
    """Return the draft worker's model, or a placeholder for dummy workers.

    When ``self._is_dummy`` is false, the wrapped worker's ``get_model()``
    is called inside ``self._patch_tensor_parallel_group()``. Otherwise a
    fresh ``_DummyModel`` instance is returned on every call.
    """
    if not self._is_dummy:
        with self._patch_tensor_parallel_group():
            return self._worker.get_model()

    return _DummyModel()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
......
......@@ -4,6 +4,7 @@ from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
......@@ -403,6 +404,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def get_model(self) -> nn.Module:
    """Return the model obtained from ``self.scorer_worker.get_model()``."""
    scorer = self.scorer_worker
    return scorer.get_model()
@torch.inference_mode()
def execute_model(
self,
......
......@@ -94,22 +94,12 @@ class MultiprocExecutor(Executor):
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
Args:
method: Name of the worker method to execute
timeout: Maximum time in seconds to wait for execution. Raises a
TimeoutError on timeout. None means wait indefinitely.
args: Positional arguments to pass to the worker method
kwargs: Keyword arguments to pass to the worker method
Returns:
List of results from each worker
"""
start_time = time.monotonic()
kwargs = kwargs or {}
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
try:
if isinstance(method, str):
send_method = method
......
......@@ -689,6 +689,9 @@ class GPUModelRunner:
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
def get_model(self) -> nn.Module:
    """Return the module stored on this runner as ``self.model``."""
    model = self.model
    return model
@torch.inference_mode()
def execute_model(
self,
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
......@@ -176,6 +177,9 @@ class Worker:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
    """Return the model exposed by ``self.model_runner.get_model()``."""
    runner = self.model_runner
    return runner.get_model()
@torch.inference_mode()
def execute_model(
self,
......
......@@ -509,6 +509,9 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
)
self.model = self.lora_manager.create_lora_manager(self.model)
def get_model(self) -> nn.Module:
    """Return whatever is currently bound to ``self.model``."""
    current_model = self.model
    return current_model
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......
......@@ -21,6 +21,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
......@@ -676,6 +677,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)
def get_model(self) -> nn.Module:
    """Return the object held in this runner's ``model`` attribute."""
    loaded = self.model
    return loaded
def _use_graphs(self, batch_size, seq_len, is_prompt):
if self.enforce_eager:
return False
......
......@@ -1176,6 +1176,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
def get_model(self) -> nn.Module:
    """Accessor for the runner's model: returns ``self.model`` unchanged."""
    module = self.model
    return module
def save_sharded_state(
self,
path: str,
......
......@@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Type, TypeVar)
import torch
import torch.nn as nn
from torch import is_tensor
from vllm.config import VllmConfig
......@@ -264,6 +265,10 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise NotImplementedError
@abstractmethod
def get_model(self) -> nn.Module:
    """Return the model managed by this runner.

    Abstract: this base implementation only raises ``NotImplementedError``;
    concrete runners must override it.
    """
    raise NotImplementedError
def execute_model(
self,
model_input: T,
......@@ -297,9 +302,9 @@ class ModelRunnerWrapperBase:
def __init__(
    self,
    model_runner: ModelRunnerBase,
) -> None:
    """Store the runner that this wrapper delegates to.

    The misspelled ``moderl_runner`` parameter and its duplicate
    assignment are dropped, leaving the single correctly named
    ``model_runner`` parameter.

    Args:
        model_runner: The runner to wrap. It is kept as
            ``self.model_runner``; ``__getattr__`` forwards attribute
            lookups that miss on the wrapper to this object.
    """
    self.model_runner: ModelRunnerBase = model_runner
def __getattr__(self, attr):
    """Look up ``attr`` on the wrapped ``self.model_runner`` and return it.

    Python calls ``__getattr__`` only after normal attribute lookup on the
    wrapper has failed, so attributes defined on the wrapper take priority.
    """
    wrapped = self.model_runner
    return getattr(wrapped, attr)
......@@ -113,6 +113,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
def get_model(self) -> nn.Module:
    """Return the value of ``self.model`` for this runner."""
    model = self.model
    return model
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......
......@@ -84,6 +84,9 @@ class OpenVINOModelRunner(ModelRunnerBase):
kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core)
def get_model(self) -> nn.Module:
    """Hand back the runner's ``model`` attribute as-is."""
    held_model = self.model
    return held_model
def _prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......
......@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import openvino as ov
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import get_attn_backend
......@@ -362,6 +363,9 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
) -> None:
self.cache_engine.copy(blocks_to_copy) # type: ignore
def get_model(self) -> nn.Module:
    """Delegate to ``self.model_runner.get_model()`` and return its result."""
    runner = self.model_runner
    return runner.get_model()
@torch.inference_mode()
def execute_model(
self,
......
......@@ -158,6 +158,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
fullgraph=True,
dynamic=False)
def get_model(self) -> nn.Module:
    """Return the ``model`` attribute of the object stored in ``self.model``.

    NOTE(review): ``self.model`` is presumably a wrapper around the
    underlying module, hence the extra ``.model`` hop; confirm against the
    code that assigns ``self.model``.
    """
    outer = self.model
    return outer.model
def _dummy_run(
self,
batch_size: int,
......
......@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle
import torch
import torch.nn as nn
from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
......@@ -90,6 +91,11 @@ class WorkerBase(ABC):
if output is None:
return None
@abstractmethod
def get_model(self) -> nn.Module:
    """Return the model used by this worker.

    Abstract: this base implementation only raises ``NotImplementedError``;
    concrete workers must override it.
    """
    raise NotImplementedError
@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
......@@ -147,6 +153,9 @@ class DelegateWorkerBase(WorkerBase):
num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def get_model(self) -> nn.Module:
    """Forward to the wrapped worker: returns ``self.worker.get_model()``."""
    delegate = self.worker
    return delegate.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
......@@ -363,6 +372,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
else:
return self._get_worker_input_from_broadcast()
def get_model(self) -> nn.Module:
    """Return the result of calling ``get_model()`` on ``self.model_runner``."""
    runner = self.model_runner
    return runner.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
......
......@@ -416,6 +416,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def get_model(self) -> nn.Module:
    """Return the module bound to ``self.model`` on this runner."""
    model = self.model
    return model
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by ``self.model_config.get_vocab_size()``."""
    config = self.model_config
    return config.get_vocab_size()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment