Unverified Commit a111d015 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[platforms] absorb worker cls difference into platforms folder (#10555)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 446c7806
import dataclasses import dataclasses
import importlib
import os import os
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
...@@ -15,7 +14,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -15,7 +14,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables) resolve_obj_by_qualname, update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput, from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
...@@ -411,23 +410,14 @@ class WorkerWrapperBase: ...@@ -411,23 +410,14 @@ class WorkerWrapperBase:
We first instantiate the WorkerWrapper, which remembers the worker module We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`. real initialization happens in `init_worker`.
If worker_class_fn is specified, it will be executed to get the worker
class.
Otherwise, the worker class will be obtained by dynamically importing it
using worker_module_name and worker_class_name.
""" """
def __init__( def __init__(
self, self,
worker_module_name: str, vllm_config: VllmConfig,
worker_class_name: str, ) -> None:
trust_remote_code: bool = False, self.vllm_config = vllm_config
worker_class_fn: Optional[Callable[[], trust_remote_code = vllm_config.model_config.trust_remote_code
Type[WorkerBase]]] = None) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker_class_fn = worker_class_fn
self.worker: Optional[WorkerBase] = None self.worker: Optional[WorkerBase] = None
if trust_remote_code: if trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
...@@ -456,12 +446,8 @@ class WorkerWrapperBase: ...@@ -456,12 +446,8 @@ class WorkerWrapperBase:
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
load_general_plugins() load_general_plugins()
if self.worker_class_fn: worker_class = resolve_obj_by_qualname(
worker_class = self.worker_class_fn() self.vllm_config.parallel_config.worker_cls)
else:
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs) self.worker = worker_class(*args, **kwargs)
assert self.worker is not None assert self.worker is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment