Unverified Commit 7bd82002 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

[Core] Allow specifying custom Executor (#6557)

parent 2e265642
......@@ -2,6 +2,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async,
......@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "TokenizerGroup":
return cls(**init_kwargs)
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
......
......@@ -2,7 +2,7 @@ import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
......@@ -315,14 +315,23 @@ class WorkerWrapperBase:
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
If worker_class_fn is specified, it will be executed to get the worker
class.
Otherwise, the worker class will be obtained by dynamically importing it
using worker_module_name and worker_class_name.
"""
def __init__(self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False) -> None:
def __init__(
self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False,
worker_class_fn: Optional[Callable[[],
Type[WorkerBase]]] = None) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker_class_fn = worker_class_fn
self.worker: Optional[WorkerBase] = None
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
......@@ -348,8 +357,11 @@ class WorkerWrapperBase:
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
if self.worker_class_fn:
worker_class = self.worker_class_fn()
else:
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment