Unverified Commit e8937954 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[2/N] executor pass the complete config to worker/modelrunner (#9938)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1d4cfe2b
import copy
from collections import defaultdict from collections import defaultdict
from functools import cached_property from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch import torch
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
...@@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use """Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
""" """
assert "speculative_config" in kwargs vllm_config: VllmConfig = kwargs.get("vllm_config")
speculative_config: SpeculativeConfig = kwargs.get("speculative_config") speculative_config: SpeculativeConfig = vllm_config.speculative_config
assert speculative_config is not None assert speculative_config is not None
draft_worker_kwargs = kwargs.copy() draft_worker_kwargs = kwargs.copy()
...@@ -58,14 +59,16 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -58,14 +59,16 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
target_worker.model_runner.disable_logprobs =\ target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs speculative_config.disable_logprobs
draft_worker_config = copy.deepcopy(vllm_config)
draft_worker_config.model_config = speculative_config.draft_model_config
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
# TODO allow draft-model specific load config.
# Override draft-model specific worker args. # Override draft-model specific worker args.
draft_worker_kwargs.update( draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config, vllm_config=draft_worker_config,
parallel_config=speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=load_config,
) )
spec_decode_worker = SpecDecodeWorker.create_worker( spec_decode_worker = SpecDecodeWorker.create_worker(
...@@ -134,29 +137,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -134,29 +137,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs.pop("ngram_prompt_lookup_max")) draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = ( ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min")) draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0: if ngram_prompt_lookup_max > 0:
proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max) ngram_prompt_lookup_max)
else: else:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
draft_tp = draft_parallel_config.tensor_parallel_size draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size
if draft_worker_kwargs[ if draft_model_config.hf_config.model_type == "mlp_speculator":
"model_config"].hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_worker_kwargs[ elif draft_model_config.hf_config.model_type == "medusa":
"model_config"].hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs) proposer_worker = MedusaWorker(**draft_worker_kwargs)
else: else:
if draft_tp == 1: if draft_tp == 1:
draft_worker_kwargs[ draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner "model_runner_cls"] = TP1DraftModelRunner
else: else:
if draft_worker_kwargs[ if draft_model_config.hf_config.model_type == "eagle":
"model_config"].hf_config.model_type == "eagle":
raise NotImplementedError( raise NotImplementedError(
"EAGLE does not support TP > 1 yet") "EAGLE does not support TP > 1 yet")
...@@ -190,8 +191,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -190,8 +191,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the " "[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.") "MQA is only available with flash attn backend.")
if "model_config" in draft_worker_kwargs and \ if draft_model_config and \
draft_worker_kwargs["model_config"].max_model_len < \ draft_model_config.max_model_len < \
scorer_worker.model_config.max_model_len: scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True disable_mqa_scorer = True
logger.info( logger.info(
......
from typing import List, Optional from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
...@@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner): ...@@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner):
requested or not. requested or not.
""" """
def __init__(self, def __init__(
model_config: ModelConfig, self,
parallel_config: ParallelConfig, vllm_config: VllmConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None): ):
# An internal boolean member variable to indicate if token log # An internal boolean member variable to indicate if token log
# probabilities are needed or not. # probabilities are needed or not.
self.disable_logprobs = True self.disable_logprobs = True
super().__init__( super().__init__(
model_config=model_config, vllm_config=vllm_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states, return_hidden_states=return_hidden_states,
observability_config=observability_config,
) )
def prepare_model_input( def prepare_model_input(
......
...@@ -2,8 +2,9 @@ import time ...@@ -2,8 +2,9 @@ import time
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union) Union)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig) ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
...@@ -32,7 +33,7 @@ class LLMEngine: ...@@ -32,7 +33,7 @@ class LLMEngine:
def __init__( def __init__(
self, self,
vllm_config: EngineConfig, vllm_config: VllmConfig,
executor_class: Type[GPUExecutor], executor_class: Type[GPUExecutor],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -477,7 +478,7 @@ class LLMEngine: ...@@ -477,7 +478,7 @@ class LLMEngine:
return self.lora_config return self.lora_config
@classmethod @classmethod
def _get_executor_cls(cls, engine_config: EngineConfig): def _get_executor_cls(cls, engine_config: VllmConfig):
return GPUExecutor return GPUExecutor
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:
......
...@@ -56,19 +56,10 @@ class GPUExecutor: ...@@ -56,19 +56,10 @@ class GPUExecutor:
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
return Worker( return Worker(
model_config=self.model_config, vllm_config=self.vllm_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
observability_config=self.observability_config,
) )
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
......
...@@ -7,9 +7,7 @@ import torch ...@@ -7,9 +7,7 @@ import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -33,26 +31,25 @@ class GPUModelRunner: ...@@ -33,26 +31,25 @@ class GPUModelRunner:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
): ):
self.model_config = model_config # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config = parallel_config self.vllm_config = vllm_config
self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config
self.device_config = device_config self.cache_config = vllm_config.cache_config
self.cache_config = cache_config self.lora_config = vllm_config.lora_config
self.lora_config = lora_config self.load_config = vllm_config.load_config
self.load_config = load_config self.parallel_config = vllm_config.parallel_config
self.prompt_adapter_config = prompt_adapter_config self.scheduler_config = vllm_config.scheduler_config
self.observability_config = observability_config self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = self.device_config.device self.device = self.device_config.device
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype self.dtype = self.model_config.dtype
......
...@@ -6,10 +6,7 @@ from typing import TYPE_CHECKING, Optional, Tuple ...@@ -6,10 +6,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
...@@ -30,48 +27,35 @@ class Worker: ...@@ -30,48 +27,35 @@ class Worker:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
speculative_config: Optional[SpeculativeConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
): ):
self.model_config = model_config
self.parallel_config = parallel_config # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.scheduler_config = scheduler_config self.vllm_config = vllm_config
self.device_config = device_config self.model_config = vllm_config.model_config
self.cache_config = cache_config self.cache_config = vllm_config.cache_config
self.load_config = load_config self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.model_runner = GPUModelRunner( self.model_runner = GPUModelRunner(vllm_config)
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
)
def initialize(self): def initialize(self):
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
......
...@@ -8,9 +8,7 @@ import torch ...@@ -8,9 +8,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -412,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -412,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
**kwargs, **kwargs,
): ):
self.model_config = model_config ModelRunnerBase.__init__(self, vllm_config)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill. # Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config model_config = self.model_config
self.cache_config = cache_config cache_config = self.cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.device = self.device_config.device self.device = self.device_config.device
......
...@@ -6,9 +6,8 @@ import torch.distributed ...@@ -6,9 +6,8 @@ import torch.distributed
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ModelConfig, ParallelConfig, PromptAdapterConfig, ParallelConfig, VllmConfig)
SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -18,7 +17,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE ...@@ -18,7 +17,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
...@@ -166,15 +154,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -166,15 +154,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self._is_encoder_decoder_model(): if self._is_encoder_decoder_model():
ModelRunnerClass = CPUEncoderDecoderModelRunner ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunner = ModelRunnerClass( self.model_runner: CPUModelRunner = ModelRunnerClass(
model_config, vllm_config=vllm_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
......
...@@ -3,9 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union ...@@ -3,9 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -36,29 +34,13 @@ class EmbeddingModelRunner( ...@@ -36,29 +34,13 @@ class EmbeddingModelRunner(
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
): ):
super().__init__(model_config, super().__init__(vllm_config=vllm_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker)
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
......
...@@ -11,9 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID ...@@ -11,9 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend, get_global_forced_attn_backend,
global_force_attn_backend) global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import ModelConfig, VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -85,17 +83,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -85,17 +83,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
...@@ -107,15 +97,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -107,15 +97,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with models) but these arguments are present here for compatibility with
the base-class constructor. the base-class constructor.
''' '''
self._maybe_force_supported_attention_backend(model_config) self._maybe_force_supported_attention_backend(vllm_config.model_config)
super().__init__( super().__init__(
model_config, vllm_config=vllm_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=None,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
) )
......
...@@ -20,9 +20,7 @@ from vllm.attention.backends.abstract import AttentionState ...@@ -20,9 +20,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
...@@ -955,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -955,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
self.model_config = model_config
self.parallel_config = parallel_config ModelRunnerBase.__init__(self, vllm_config)
self.scheduler_config = scheduler_config model_config = self.model_config
self.device_config = device_config cache_config = self.cache_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
self.observability_config = observability_config
self.device = self.device_config.device self.device = self.device_config.device
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
......
...@@ -9,6 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, ...@@ -9,6 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
import torch import torch
from torch import is_tensor from torch import is_tensor
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -220,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -220,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass. ModelRunnerInputBase subclass.
""" """
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
# Map of request_id -> generator used for seeded random sampling # Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {} generators: Dict[str, torch.Generator] = {}
......
...@@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# mypy: enable-error-code=type-var # mypy: enable-error-code=type-var
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Check attention backend support. # Check attention backend support.
......
...@@ -27,17 +27,9 @@ class MultiStepWorker(Worker): ...@@ -27,17 +27,9 @@ class MultiStepWorker(Worker):
# for multi-step model, wrap the model runner with MultiStepModelRunner # for multi-step model, wrap the model runner with MultiStepModelRunner
self.model_runner = MultiStepModelRunner( self.model_runner = MultiStepModelRunner(
base_model_runner, base_model_runner,
base_model_runner.model_config, vllm_config=base_model_runner.vllm_config,
base_model_runner.parallel_config,
base_model_runner.scheduler_config,
base_model_runner.device_config,
base_model_runner.cache_config,
load_config=base_model_runner.load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=base_model_runner.is_driver_worker, is_driver_worker=base_model_runner.is_driver_worker,
prompt_adapter_config=base_model_runner.prompt_adapter_config,
observability_config=base_model_runner.observability_config,
) )
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
......
...@@ -7,8 +7,7 @@ import torch ...@@ -7,8 +7,7 @@ import torch
from torch import nn from torch import nn
from transformers_neuronx.config import GenerationConfig from transformers_neuronx.config import GenerationConfig
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, from vllm.config import VllmConfig
SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
): ):
self.model_config = model_config ModelRunnerBase.__init__(self, vllm_config)
self.parallel_config = parallel_config model_config = self.model_config
self.scheduler_config = scheduler_config
if model_config is not None and model_config.get_sliding_window(): if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on Neuron. " logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.") "The model will run without sliding window.")
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device self.device = self.device_config.device
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
......
...@@ -4,15 +4,15 @@ from typing import List, Optional, Tuple ...@@ -4,15 +4,15 @@ from typing import List, Optional, Tuple
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import VllmConfig
ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...@@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
) -> None: ) -> None:
self.model_config = model_config WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
...@@ -44,7 +36,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -44,7 +36,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
init_cached_hf_modules() init_cached_hf_modules()
self.model_runner: NeuronModelRunner = NeuronModelRunner( self.model_runner: NeuronModelRunner = NeuronModelRunner(
model_config, parallel_config, scheduler_config, device_config) vllm_config=vllm_config)
self.is_driver_worker = True self.is_driver_worker = True
def init_device(self) -> None: def init_device(self) -> None:
......
...@@ -7,9 +7,7 @@ from torch import nn ...@@ -7,9 +7,7 @@ from torch import nn
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -17,6 +15,7 @@ from vllm.model_executor.model_loader.openvino import get_model ...@@ -17,6 +15,7 @@ from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap) MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -39,33 +38,21 @@ class ModelInput(NamedTuple): ...@@ -39,33 +38,21 @@ class ModelInput(NamedTuple):
multi_modal_kwargs={}) multi_modal_kwargs={})
class OpenVINOModelRunner: class OpenVINOModelRunner(ModelRunnerBase):
def __init__( def __init__(
self, self,
ov_core: ov.Core, ov_core: ov.Core,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
*args, *args,
**kwargs, **kwargs,
): ):
self.ov_core = ov_core self.ov_core = ov_core
self.model_config = model_config ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config = parallel_config cache_config = self.cache_config
self.scheduler_config = scheduler_config model_config = self.model_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.device = self.device_config.device self.device = self.device_config.device
...@@ -369,3 +356,9 @@ class OpenVINOModelRunner: ...@@ -369,3 +356,9 @@ class OpenVINOModelRunner:
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
return output return output
def prepare_model_input(self, *args, **kwargs):
raise NotImplementedError
def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
raise NotImplementedError
...@@ -7,9 +7,8 @@ import torch.distributed ...@@ -7,9 +7,8 @@ import torch.distributed
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ParallelConfig, VllmConfig)
SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
...@@ -22,7 +21,7 @@ from vllm.platforms import current_platform ...@@ -22,7 +21,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ...@@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
def __init__( def __init__(
self, self,
ov_core: ov.Core, ov_core: ov.Core,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.ov_core = ov_core self.ov_core = ov_core
self.model_config = model_config WorkerBase.__init__(self, vllm_config)
self.parallel_config = parallel_config
self.parallel_config.rank = rank self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
...@@ -250,14 +235,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ...@@ -250,14 +235,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
init_cached_hf_modules() init_cached_hf_modules()
self.model_runner = OpenVINOModelRunner( self.model_runner = OpenVINOModelRunner(
self.ov_core, self.ov_core,
model_config, vllm_config=self.vllm_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
) )
......
...@@ -12,8 +12,7 @@ import torch_xla.runtime as xr ...@@ -12,8 +12,7 @@ import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import VllmConfig
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
self.model_config = model_config ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
......
...@@ -6,8 +6,7 @@ import torch_xla.core.xla_model as xm ...@@ -6,8 +6,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import VllmConfig
ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -16,7 +15,8 @@ from vllm.sequence import ExecuteModelRequest ...@@ -16,7 +15,8 @@ from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
is_driver_worker: bool, is_driver_worker: bool,
) -> None: ) -> None:
self.model_config = model_config WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config = parallel_config
self.parallel_config.rank = rank self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
...@@ -56,13 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -56,13 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cache_config.cache_dtype] self.cache_config.cache_dtype]
self.model_runner: TPUModelRunner = TPUModelRunner( self.model_runner: TPUModelRunner = TPUModelRunner(
model_config, vllm_config=vllm_config, is_driver_worker=is_driver_worker)
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None: def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU" os.environ["PJRT_DEVICE"] = "TPU"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment