Unverified commit e8937954, authored by youkaichao, committed by GitHub
Browse files

[2/N] executor pass the complete config to worker/modelrunner (#9938)


Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent 1d4cfe2b
...@@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init): ...@@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init):
enable_lora=True) enable_lora=True)
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
model_runner = ModelRunner( model_runner = ModelRunner(
model_config=engine_config.model_config, vllm_config=engine_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
is_driver_worker=True, is_driver_worker=True,
) )
model_runner.load_model() model_runner.load_model()
......
...@@ -4,7 +4,8 @@ import tempfile ...@@ -4,7 +4,8 @@ import tempfile
from unittest.mock import patch from unittest.mock import patch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig) ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -12,7 +13,7 @@ from vllm.worker.worker import Worker ...@@ -12,7 +13,7 @@ from vllm.worker.worker import Worker
@patch.dict(os.environ, {"RANK": "0"}) @patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files): def test_worker_apply_lora(sql_lora_files):
worker = Worker( vllm_config = VllmConfig(
model_config=ModelConfig( model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-7b-hf",
task="auto", task="auto",
...@@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files): ...@@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files):
gpu_memory_utilization=1., gpu_memory_utilization=1.,
swap_space=0, swap_space=0,
cache_dtype="auto"), cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32), max_loras=32),
)
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,
distributed_init_method=f"file://{tempfile.mkstemp()[1]}", distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
) )
worker.init_device() worker.init_device()
......
...@@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T], ...@@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T],
get_ip(), get_open_port()) get_ip(), get_open_port())
worker = cls( worker = cls(
model_config=engine_config.model_config, vllm_config=engine_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
...@@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args, ...@@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args,
engine_args = EngineArgs(model, *args, **kwargs) engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
model_runner = EncoderDecoderModelRunner( model_runner = EncoderDecoderModelRunner(
model_config=engine_config.model_config, vllm_config=engine_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True, is_driver_worker=True,
) )
return model_runner return model_runner
......
...@@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: ...@@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
engine_args = EngineArgs(model, *args, **kwargs) engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
model_runner = ModelRunner( model_runner = ModelRunner(
model_config=engine_config.model_config, vllm_config=engine_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
observability_config=engine_config.observability_config,
is_driver_worker=True, is_driver_worker=True,
) )
return model_runner return model_runner
......
...@@ -24,12 +24,7 @@ def test_gpu_memory_profiling(): ...@@ -24,12 +24,7 @@ def test_gpu_memory_profiling():
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
worker = Worker( worker = Worker(
model_config=engine_config.model_config, vllm_config=engine_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
...@@ -19,12 +19,7 @@ def test_swap() -> None: ...@@ -19,12 +19,7 @@ def test_swap() -> None:
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
worker = Worker( worker = Worker(
model_config=engine_config.model_config, vllm_config=engine_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
import enum import enum
import json import json
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union) Mapping, Optional, Set, Tuple, Type, Union)
...@@ -1941,9 +1941,9 @@ class ObservabilityConfig: ...@@ -1941,9 +1941,9 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}") f"installed. Original error:\n{otel_import_error_traceback}")
@dataclass(frozen=True) @dataclass
class EngineConfig: class VllmConfig:
"""Dataclass which contains all engine-related configuration. This """Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
""" """
...@@ -1953,11 +1953,11 @@ class EngineConfig: ...@@ -1953,11 +1953,11 @@ class EngineConfig:
scheduler_config: SchedulerConfig scheduler_config: SchedulerConfig
device_config: DeviceConfig device_config: DeviceConfig
load_config: LoadConfig load_config: LoadConfig
lora_config: Optional[LoRAConfig] lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] speculative_config: Optional[SpeculativeConfig] = None
decoding_config: Optional[DecodingConfig] decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] prompt_adapter_config: Optional[PromptAdapterConfig] = None
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
...@@ -1975,9 +1975,3 @@ class EngineConfig: ...@@ -1975,9 +1975,3 @@ class EngineConfig:
if self.prompt_adapter_config: if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config( self.prompt_adapter_config.verify_with_model_config(
self.model_config) self.model_config)
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
"""
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
...@@ -9,10 +9,11 @@ import torch ...@@ -9,10 +9,11 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
LoRAConfig, ModelConfig, ObservabilityConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig) SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...@@ -955,7 +956,7 @@ class EngineArgs: ...@@ -955,7 +956,7 @@ class EngineArgs:
ignore_patterns=self.ignore_patterns, ignore_patterns=self.ignore_patterns,
) )
def create_engine_config(self) -> EngineConfig: def create_engine_config(self) -> VllmConfig:
# gguf file needs a specific model loader and doesn't use hf_repo # gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model): if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf" self.quantization = self.load_format = "gguf"
...@@ -1167,7 +1168,7 @@ class EngineArgs: ...@@ -1167,7 +1168,7 @@ class EngineArgs:
or "all" in detailed_trace_modules, or "all" in detailed_trace_modules,
) )
return EngineConfig( return VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
......
...@@ -7,8 +7,8 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, ...@@ -7,8 +7,8 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
from weakref import ReferenceType from weakref import ReferenceType
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
...@@ -604,7 +604,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -604,7 +604,7 @@ class AsyncLLMEngine(EngineClient):
@classmethod @classmethod
def _get_executor_cls( def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type): if isinstance(distributed_executor_backend, type):
...@@ -663,7 +663,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -663,7 +663,7 @@ class AsyncLLMEngine(EngineClient):
def from_engine_args( def from_engine_args(
cls, cls,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None, engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
......
...@@ -13,8 +13,9 @@ import torch ...@@ -13,8 +13,9 @@ import torch
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig) ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs) SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
...@@ -219,7 +220,7 @@ class LLMEngine: ...@@ -219,7 +220,7 @@ class LLMEngine:
def __init__( def __init__(
self, self,
vllm_config: EngineConfig, vllm_config: VllmConfig,
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -500,7 +501,7 @@ class LLMEngine: ...@@ -500,7 +501,7 @@ class LLMEngine:
@classmethod @classmethod
def _get_executor_cls(cls, def _get_executor_cls(cls,
engine_config: EngineConfig) -> Type[ExecutorBase]: engine_config: VllmConfig) -> Type[ExecutorBase]:
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
......
...@@ -13,7 +13,7 @@ from zmq import Frame # type: ignore[attr-defined] ...@@ -13,7 +13,7 @@ from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket from zmq.asyncio import Socket
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy every N seconds, confirming the engine is healthy
""" """
def __init__(self, ipc_path: str, engine_config: EngineConfig, def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int): engine_pid: int):
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
......
...@@ -138,18 +138,11 @@ class CPUExecutor(ExecutorBase): ...@@ -138,18 +138,11 @@ class CPUExecutor(ExecutorBase):
assert self.distributed_init_method is not None assert self.distributed_init_method is not None
kwargs = dict( kwargs = dict(
model_config=self.model_config, vllm_config=self.vllm_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=self.distributed_init_method, distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0, is_driver_worker=rank == 0,
) )
wrapper.init_worker(**kwargs) wrapper.init_worker(**kwargs)
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from vllm.config import EngineConfig from vllm.config import VllmConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -20,7 +20,7 @@ class ExecutorBase(ABC): ...@@ -20,7 +20,7 @@ class ExecutorBase(ABC):
def __init__( def __init__(
self, self,
vllm_config: EngineConfig, vllm_config: VllmConfig,
) -> None: ) -> None:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
......
...@@ -49,21 +49,12 @@ class GPUExecutor(ExecutorBase): ...@@ -49,21 +49,12 @@ class GPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
return dict( return dict(
model_config=self.model_config, vllm_config=self.vllm_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config) is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0), or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
) )
def _get_worker_module_and_class( def _get_worker_module_and_class(
......
...@@ -29,11 +29,7 @@ class NeuronExecutor(ExecutorBase): ...@@ -29,11 +29,7 @@ class NeuronExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = NeuronWorker( self.driver_worker = NeuronWorker(
model_config=self.model_config, vllm_config=self.vllm_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method) distributed_init_method=distributed_init_method)
......
...@@ -48,16 +48,10 @@ class OpenVINOExecutor(ExecutorBase): ...@@ -48,16 +48,10 @@ class OpenVINOExecutor(ExecutorBase):
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker( self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core, ov_core=self.ov_core,
model_config=self.model_config, vllm_config=self.vllm_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
......
...@@ -44,12 +44,7 @@ class TPUExecutor(ExecutorBase): ...@@ -44,12 +44,7 @@ class TPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
return dict( return dict(
model_config=self.model_config, vllm_config=self.vllm_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
......
...@@ -17,9 +17,6 @@ except (ModuleNotFoundError, ImportError) as err: ...@@ -17,9 +17,6 @@ except (ModuleNotFoundError, ImportError) as err:
"Draft model speculative decoding currently only supports" "Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.") from err "CUDA and ROCm flash attention backend.") from err
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalInputs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
...@@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model). any broadcasting inside execute_model).
""" """
def __init__( def __init__(self, *args, **kwargs):
self, if kwargs.get("return_hidden_states"):
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
if return_hidden_states:
raise ValueError( raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner." "return_hidden_states is not supported for TP1DraftModelRunner."
) )
super().__init__( super().__init__(*args, **kwargs)
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)
def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries): num_queries):
......
...@@ -21,7 +21,7 @@ class NGramWorker(NonLLMProposerWorkerBase): ...@@ -21,7 +21,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# Get local_rank/vocab_size from kwargs attribute # Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"] self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["model_config"].get_vocab_size() self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
# Lazy initialization list. # Lazy initialization list.
self._proposer: Top1Proposer self._proposer: Top1Proposer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment