Unverified commit 59a0192f, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Interface for accessing model from `VllmRunner` (#10353)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 83609791
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
import torch import torch
import torch.nn as nn
from vllm.distributed.parallel_state import (get_tp_group, from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group, init_model_parallel_group,
...@@ -15,6 +16,10 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase ...@@ -15,6 +16,10 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
class _DummyModel(nn.Module):
pass
class SmallerTpProposerWorker(ProposerWorkerBase): class SmallerTpProposerWorker(ProposerWorkerBase):
"""Class which allows a speculative draft model to run with smaller tensor """Class which allows a speculative draft model to run with smaller tensor
parallel degree than target model. parallel degree than target model.
...@@ -139,6 +144,13 @@ class SmallerTpProposerWorker(ProposerWorkerBase): ...@@ -139,6 +144,13 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
return self._worker.get_spec_proposals( return self._worker.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step) execute_model_req, seq_ids_with_bonus_token_in_last_step)
def get_model(self) -> nn.Module:
if self._is_dummy:
return _DummyModel()
with self._patch_tensor_parallel_group():
return self._worker.get_model()
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
......
...@@ -4,6 +4,7 @@ from functools import cached_property ...@@ -4,6 +4,7 @@ from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.distributed.communication_op import broadcast_tensor_dict
...@@ -403,6 +404,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -403,6 +404,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks) num_cpu_blocks=num_cpu_blocks)
def get_model(self) -> nn.Module:
return self.scorer_worker.get_model()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
......
...@@ -94,22 +94,12 @@ class MultiprocExecutor(Executor): ...@@ -94,22 +94,12 @@ class MultiprocExecutor(Executor):
timeout: Optional[float] = None, timeout: Optional[float] = None,
args: Tuple = (), args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]: kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
Args:
method: Name of the worker method to execute
timeout: Maximum time in seconds to wait for execution. Raises a
TimeoutError on timeout. None means wait indefinitely.
args: Positional arguments to pass to the worker method
kwargs: Keyword arguments to pass to the worker method
Returns:
List of results from each worker
"""
start_time = time.monotonic() start_time = time.monotonic()
kwargs = kwargs or {} kwargs = kwargs or {}
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
try: try:
if isinstance(method, str): if isinstance(method, str):
send_method = method send_method = method
......
...@@ -689,6 +689,9 @@ class GPUModelRunner: ...@@ -689,6 +689,9 @@ class GPUModelRunner:
encoder_outputs.append(encoder_output[start_idx:end_idx]) encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs return encoder_outputs
def get_model(self) -> nn.Module:
return self.model
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
......
...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
import torch.distributed import torch.distributed
import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
...@@ -176,6 +177,9 @@ class Worker: ...@@ -176,6 +177,9 @@ class Worker:
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
......
...@@ -509,6 +509,9 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -509,6 +509,9 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
) )
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
......
...@@ -21,6 +21,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, ...@@ -21,6 +21,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc import habana_frameworks.torch.internal.bridge_config as bc
import torch import torch
import torch.nn as nn
from vllm_hpu_extension.ops import LoraMask as LoraMask from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes) HabanaMemoryProfiler, format_bytes)
...@@ -676,6 +677,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -676,6 +677,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
msg = f"Loading model weights took in total {m.get_summary_string()}" msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg) logger.info(msg)
def get_model(self) -> nn.Module:
return self.model
def _use_graphs(self, batch_size, seq_len, is_prompt): def _use_graphs(self, batch_size, seq_len, is_prompt):
if self.enforce_eager: if self.enforce_eager:
return False return False
......
...@@ -1176,6 +1176,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1176,6 +1176,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend) backend=backend)
def get_model(self) -> nn.Module:
return self.model
def save_sharded_state( def save_sharded_state(
self, self,
path: str, path: str,
......
...@@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, ...@@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Type, TypeVar) Optional, Type, TypeVar)
import torch import torch
import torch.nn as nn
from torch import is_tensor from torch import is_tensor
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -264,6 +265,10 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -264,6 +265,10 @@ class ModelRunnerBase(ABC, Generic[T]):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
def execute_model( def execute_model(
self, self,
model_input: T, model_input: T,
...@@ -297,9 +302,9 @@ class ModelRunnerWrapperBase: ...@@ -297,9 +302,9 @@ class ModelRunnerWrapperBase:
def __init__( def __init__(
self, self,
moderl_runner: ModelRunnerBase, model_runner: ModelRunnerBase,
) -> None: ) -> None:
self.model_runner: ModelRunnerBase = moderl_runner self.model_runner: ModelRunnerBase = model_runner
def __getattr__(self, attr): def __getattr__(self, attr):
return getattr(self.model_runner, attr) return getattr(self.model_runner, attr)
...@@ -113,6 +113,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -113,6 +113,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
raise NotImplementedError( raise NotImplementedError(
"Supports only Transformer-NeuronX based models.") "Supports only Transformer-NeuronX based models.")
def get_model(self) -> nn.Module:
return self.model
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
......
...@@ -84,6 +84,9 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -84,6 +84,9 @@ class OpenVINOModelRunner(ModelRunnerBase):
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core) ov_core=self.ov_core)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input( def _prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
......
...@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import openvino as ov import openvino as ov
import torch import torch
import torch.distributed import torch.distributed
import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
...@@ -362,6 +363,9 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ...@@ -362,6 +363,9 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
) -> None: ) -> None:
self.cache_engine.copy(blocks_to_copy) # type: ignore self.cache_engine.copy(blocks_to_copy) # type: ignore
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
......
...@@ -158,6 +158,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -158,6 +158,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
fullgraph=True, fullgraph=True,
dynamic=False) dynamic=False)
def get_model(self) -> nn.Module:
return self.model.model
def _dummy_run( def _dummy_run(
self, self,
batch_size: int, batch_size: int,
......
...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle import cloudpickle
import torch import torch
import torch.nn as nn
from vllm.config import ObservabilityConfig, VllmConfig from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
...@@ -90,6 +91,11 @@ class WorkerBase(ABC): ...@@ -90,6 +91,11 @@ class WorkerBase(ABC):
if output is None: if output is None:
return None return None
@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
@abstractmethod
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
...@@ -147,6 +153,9 @@ class DelegateWorkerBase(WorkerBase): ...@@ -147,6 +153,9 @@ class DelegateWorkerBase(WorkerBase):
num_cpu_blocks: int) -> None: num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def get_model(self) -> nn.Module:
return self.worker.get_model()
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
...@@ -363,6 +372,9 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -363,6 +372,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
else: else:
return self._get_worker_input_from_broadcast() return self._get_worker_input_from_broadcast()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None, execute_model_req: Optional[ExecuteModelRequest] = None,
......
...@@ -416,6 +416,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -416,6 +416,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30)) self.model_memory_usage / float(2**30))
def get_model(self) -> nn.Module:
return self.model
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment