Unverified Commit e8937954 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[2/N] executor pass the complete config to worker/modelrunner (#9938)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1d4cfe2b
import copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
......@@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert "speculative_config" in kwargs
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
vllm_config: VllmConfig = kwargs.get("vllm_config")
speculative_config: SpeculativeConfig = vllm_config.speculative_config
assert speculative_config is not None
draft_worker_kwargs = kwargs.copy()
......@@ -58,14 +59,16 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs
draft_worker_config = copy.deepcopy(vllm_config)
draft_worker_config.model_config = speculative_config.draft_model_config
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
# TODO allow draft-model specific load config.
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
parallel_config=speculative_config.draft_parallel_config,
vllm_config=draft_worker_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
......@@ -134,29 +137,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0:
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size
if draft_worker_kwargs[
"model_config"].hf_config.model_type == "mlp_speculator":
if draft_model_config.hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "medusa":
elif draft_model_config.hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_worker_kwargs[
"model_config"].hf_config.model_type == "eagle":
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
"EAGLE does not support TP > 1 yet")
......@@ -190,8 +191,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if "model_config" in draft_worker_kwargs and \
draft_worker_kwargs["model_config"].max_model_len < \
if draft_model_config and \
draft_model_config.max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info(
......
from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
......@@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner):
requested or not.
"""
def __init__(self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None):
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
self.disable_logprobs = True
super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)
def prepare_model_input(
......
......@@ -2,8 +2,9 @@ import time
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
......@@ -32,7 +33,7 @@ class LLMEngine:
def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
......@@ -477,7 +478,7 @@ class LLMEngine:
return self.lora_config
@classmethod
def _get_executor_cls(cls, engine_config: EngineConfig):
def _get_executor_cls(cls, engine_config: VllmConfig):
return GPUExecutor
def is_tracing_enabled(self) -> bool:
......
......@@ -56,19 +56,10 @@ class GPUExecutor:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
observability_config=self.observability_config,
)
def determine_num_available_blocks(self) -> Tuple[int, int]:
......
......@@ -7,9 +7,7 @@ import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
......@@ -33,26 +31,25 @@ class GPUModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
vllm_config: VllmConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
# TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
......
......@@ -6,10 +6,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
......@@ -30,48 +27,35 @@ class Worker:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
speculative_config: Optional[SpeculativeConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = GPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
)
self.model_runner = GPUModelRunner(vllm_config)
def initialize(self):
if self.device_config.device.type == "cuda":
......
......@@ -8,9 +8,7 @@ import torch
from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
......@@ -412,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
ModelRunnerBase.__init__(self, vllm_config)
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.device = self.device_config.device
......
......@@ -6,9 +6,8 @@ import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -18,7 +17,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
......@@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
......@@ -166,15 +154,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self._is_encoder_decoder_model():
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunner = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
......
......@@ -3,9 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
......@@ -36,29 +34,13 @@ class EmbeddingModelRunner(
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
super().__init__(model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
super().__init__(vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config)
is_driver_worker=is_driver_worker)
@torch.inference_mode()
def execute_model(
......
......@@ -11,9 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
......@@ -85,17 +83,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
......@@ -107,15 +97,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend(model_config)
self._maybe_force_supported_attention_backend(vllm_config.model_config)
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=None,
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
......
......@@ -20,9 +20,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
......@@ -955,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.return_hidden_states = return_hidden_states
self.observability_config = observability_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
......
......@@ -9,6 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
import torch
from torch import is_tensor
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
......@@ -220,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
"""
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}
......
......@@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# mypy: enable-error-code=type-var
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check attention backend support.
......
......@@ -27,17 +27,9 @@ class MultiStepWorker(Worker):
# for multi-step model, wrap the model runner with MultiStepModelRunner
self.model_runner = MultiStepModelRunner(
base_model_runner,
base_model_runner.model_config,
base_model_runner.parallel_config,
base_model_runner.scheduler_config,
base_model_runner.device_config,
base_model_runner.cache_config,
load_config=base_model_runner.load_config,
lora_config=self.lora_config,
vllm_config=base_model_runner.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=base_model_runner.is_driver_worker,
prompt_adapter_config=base_model_runner.prompt_adapter_config,
observability_config=base_model_runner.observability_config,
)
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
......
......@@ -7,8 +7,7 @@ import torch
from torch import nn
from transformers_neuronx.config import GenerationConfig
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
......@@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
vllm_config: VllmConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
......
......@@ -4,15 +4,15 @@ from typing import List, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
......@@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
......@@ -44,7 +36,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
init_cached_hf_modules()
self.model_runner: NeuronModelRunner = NeuronModelRunner(
model_config, parallel_config, scheduler_config, device_config)
vllm_config=vllm_config)
self.is_driver_worker = True
def init_device(self) -> None:
......
......@@ -7,9 +7,7 @@ from torch import nn
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
......@@ -17,6 +15,7 @@ from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase
logger = init_logger(__name__)
......@@ -39,33 +38,21 @@ class ModelInput(NamedTuple):
multi_modal_kwargs={})
class OpenVINOModelRunner:
class OpenVINOModelRunner(ModelRunnerBase):
def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
cache_config = self.cache_config
model_config = self.model_config
self.is_driver_worker = is_driver_worker
self.device = self.device_config.device
......@@ -369,3 +356,9 @@ class OpenVINOModelRunner:
sampling_metadata=sampling_metadata,
)
return output
def prepare_model_input(self, *args, **kwargs):
raise NotImplementedError
def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
raise NotImplementedError
......@@ -7,9 +7,8 @@ import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
......@@ -22,7 +21,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
......@@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
......@@ -250,14 +235,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
init_cached_hf_modules()
self.model_runner = OpenVINOModelRunner(
self.ov_core,
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
vllm_config=self.vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
......
......@@ -12,8 +12,7 @@ import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
......@@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
......
......@@ -6,8 +6,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -16,7 +15,8 @@ from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
......@@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
......@@ -56,13 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cache_config.cache_dtype]
self.model_runner: TPUModelRunner = TPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
is_driver_worker=is_driver_worker)
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment