Unverified commit aafd4d23, authored by Cyrus Leung, committed by GitHub
Browse files

[Chore] Try remove `init_cached_hf_modules` (#31786)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 0a2c2dc3
...@@ -29,11 +29,7 @@ class RunaiDummyExecutor(UniProcExecutor): ...@@ -29,11 +29,7 @@ class RunaiDummyExecutor(UniProcExecutor):
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
) )
wrapper_kwargs = { self.driver_worker = WorkerWrapperBase()
"vllm_config": self.vllm_config,
}
self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
self.collective_rpc("init_worker", args=([worker_rpc_kwargs],)) self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
self.collective_rpc("init_device") self.collective_rpc("init_device")
...@@ -67,7 +67,7 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: d ...@@ -67,7 +67,7 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: d
class DummyExecutor(UniProcExecutor): class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
"""Initialize the worker and load the model.""" """Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) self.driver_worker = WorkerWrapperBase(rpc_rank=0)
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
local_rank = 0 local_rank = 0
# set local rank as the device index if specified # set local rank as the device index if specified
......
...@@ -23,17 +23,6 @@ from vllm.logger import init_logger ...@@ -23,17 +23,6 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
def import_pynvml(): def import_pynvml():
""" """
Historical comments: Historical comments:
......
...@@ -519,9 +519,7 @@ class WorkerProc: ...@@ -519,9 +519,7 @@ class WorkerProc:
shared_worker_lock: LockType, shared_worker_lock: LockType,
): ):
self.rank = rank self.rank = rank
wrapper = WorkerWrapperBase( wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
)
# TODO: move `init_worker` to executor level as a collective rpc call # TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: list[dict] = [ all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size) {} for _ in range(vllm_config.parallel_config.world_size)
......
...@@ -208,9 +208,7 @@ class RayDistributedExecutor(Executor): ...@@ -208,9 +208,7 @@ class RayDistributedExecutor(Executor):
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined] )(RayWorkerWrapper).remote(rpc_rank=rank)
vllm_config=self.vllm_config, rpc_rank=rank
)
else: else:
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
...@@ -218,9 +216,8 @@ class RayDistributedExecutor(Executor): ...@@ -218,9 +216,8 @@ class RayDistributedExecutor(Executor):
resources={current_platform.ray_device_key: num_gpus}, resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined] )(RayWorkerWrapper).remote(rpc_rank=rank)
vllm_config=self.vllm_config, rpc_rank=rank
)
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get( worker_ips = ray.get(
......
...@@ -26,7 +26,7 @@ logger = init_logger(__name__) ...@@ -26,7 +26,7 @@ logger = init_logger(__name__)
class UniProcExecutor(Executor): class UniProcExecutor(Executor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
"""Initialize the worker and load the model.""" """Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) self.driver_worker = WorkerWrapperBase(rpc_rank=0)
distributed_init_method, rank, local_rank = self._distributed_args() distributed_init_method, rank, local_rank = self._distributed_args()
kwargs = dict( kwargs = dict(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
......
...@@ -85,12 +85,6 @@ class Worker(WorkerBase): ...@@ -85,12 +85,6 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision) torch.set_float32_matmul_precision(precision)
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Buffers saved before sleep # Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
......
...@@ -85,12 +85,6 @@ class TPUWorker: ...@@ -85,12 +85,6 @@ class TPUWorker:
else: else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Delay profiler initialization to the start of the profiling. # Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the # This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after # TPU Worker is initialized. The profiler server needs to start after
......
...@@ -178,7 +178,6 @@ class WorkerWrapperBase: ...@@ -178,7 +178,6 @@ class WorkerWrapperBase:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig,
rpc_rank: int = 0, rpc_rank: int = 0,
global_rank: int | None = None, global_rank: int | None = None,
) -> None: ) -> None:
...@@ -194,21 +193,10 @@ class WorkerWrapperBase: ...@@ -194,21 +193,10 @@ class WorkerWrapperBase:
""" """
self.rpc_rank = rpc_rank self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.worker: WorkerBase | None = None
# do not store this `vllm_config`, `init_worker` will set the final
# one.
# TODO: investigate if we can remove this field in `WorkerWrapperBase`,
# `init_cached_hf_modules` should be unnecessary now.
self.vllm_config: VllmConfig | None = None
# `model_config` can be None in tests
model_config = vllm_config.model_config
if model_config and model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules() # Initialized after init_worker is called
self.worker: WorkerBase
self.vllm_config: VllmConfig
def shutdown(self) -> None: def shutdown(self) -> None:
if self.worker is not None: if self.worker is not None:
...@@ -241,27 +229,34 @@ class WorkerWrapperBase: ...@@ -241,27 +229,34 @@ class WorkerWrapperBase:
Arguments are passed to the worker class constructor. Arguments are passed to the worker class constructor.
""" """
kwargs = all_kwargs[self.rpc_rank] kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, ( vllm_config: VllmConfig | None = kwargs.get("vllm_config")
assert vllm_config is not None, (
"vllm_config is required to initialize the worker" "vllm_config is required to initialize the worker"
) )
self.vllm_config.enable_trace_function_call_for_thread() self.vllm_config = vllm_config
vllm_config.enable_trace_function_call_for_thread()
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
load_general_plugins() load_general_plugins()
if isinstance(self.vllm_config.parallel_config.worker_cls, str): parallel_config = vllm_config.parallel_config
worker_class = resolve_obj_by_qualname( if isinstance(parallel_config.worker_cls, str):
self.vllm_config.parallel_config.worker_cls worker_class: type[WorkerBase] = resolve_obj_by_qualname(
parallel_config.worker_cls
) )
else: else:
raise ValueError( raise ValueError(
"passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501 "passing worker_cls is no longer supported. "
"Please pass keep the class in a separate module "
"and pass the qualified name of the class as a string."
) )
if self.vllm_config.parallel_config.worker_extension_cls:
if parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname( worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls parallel_config.worker_extension_cls
) )
extended_calls = [] extended_calls = []
if worker_extension_cls not in worker_class.__bases__: if worker_extension_cls not in worker_class.__bases__:
...@@ -294,7 +289,7 @@ class WorkerWrapperBase: ...@@ -294,7 +289,7 @@ class WorkerWrapperBase:
"This argument is needed for mm_processor_cache_type='shm'." "This argument is needed for mm_processor_cache_type='shm'."
) )
mm_config = self.vllm_config.model_config.multimodal_config mm_config = vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm": if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg) raise ValueError(msg)
else: else:
...@@ -303,7 +298,7 @@ class WorkerWrapperBase: ...@@ -303,7 +298,7 @@ class WorkerWrapperBase:
self.mm_receiver_cache = None self.mm_receiver_cache = None
else: else:
self.mm_receiver_cache = worker_receiver_cache_from_config( self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, vllm_config,
MULTIMODAL_REGISTRY, MULTIMODAL_REGISTRY,
shared_worker_lock, shared_worker_lock,
) )
...@@ -311,7 +306,6 @@ class WorkerWrapperBase: ...@@ -311,7 +306,6 @@ class WorkerWrapperBase:
with set_current_vllm_config(self.vllm_config): with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization # To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs) self.worker = worker_class(**kwargs)
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank] kv_cache_config = kv_cache_configs[self.global_rank]
...@@ -358,20 +352,15 @@ class WorkerWrapperBase: ...@@ -358,20 +352,15 @@ class WorkerWrapperBase:
) )
def execute_model( def execute_model(
self, self, scheduler_output: SchedulerOutput
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output) self._apply_mm_cache(scheduler_output)
assert self.worker is not None return self.worker.execute_model(scheduler_output)
return self.worker.execute_model(scheduler_output, *args, **kwargs)
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None: if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache() mm_receiver_cache.clear_cache()
assert self.worker is not None
self.worker.reset_mm_cache() self.worker.reset_mm_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment