Unverified Commit a111d015 authored by youkaichao, committed by GitHub
Browse files

[platforms] absorb worker cls difference into platforms folder (#10555)


Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent 446c7806
import dataclasses
import importlib
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch
......@@ -15,7 +14,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
resolve_obj_by_qualname, update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
......@@ -411,23 +410,14 @@ class WorkerWrapperBase:
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
If worker_class_fn is specified, it will be executed to get the worker
class.
Otherwise, the worker class will be obtained by dynamically importing it
using worker_module_name and worker_class_name.
"""
def __init__(
self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False,
worker_class_fn: Optional[Callable[[],
Type[WorkerBase]]] = None) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker_class_fn = worker_class_fn
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
trust_remote_code = vllm_config.model_config.trust_remote_code
self.worker: Optional[WorkerBase] = None
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
......@@ -456,12 +446,8 @@ class WorkerWrapperBase:
from vllm.plugins import load_general_plugins
load_general_plugins()
if self.worker_class_fn:
worker_class = self.worker_class_fn()
else:
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment