Unverified Commit 9e8744a5 authored by Roy's avatar Roy Committed by GitHub
Browse files

[BugFix] Fix get tokenizer when using ray (#3301)

parent e4a28e53
......@@ -89,3 +89,6 @@ async def test_new_requests_event():
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
assert engine.get_tokenizer() is not None
......@@ -5,6 +5,8 @@ from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator, Callable)
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
......@@ -372,8 +374,11 @@ class AsyncLLMEngine:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)
def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
else:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None:
"""Start the background loop."""
......
......@@ -7,6 +7,8 @@ import importlib
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
from transformers import PreTrainedTokenizer
import vllm
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
......@@ -163,7 +165,11 @@ class LLMEngine:
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")
def get_tokenizer_for_seq(self, sequence: Sequence):
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer()
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _dispatch_worker(self):
......
......@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request = self._maybe_get_lora(request)
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer()))
request, await self.engine.get_tokenizer()))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
......
......@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request = self._maybe_get_lora(request)
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer()))
request, await self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
......
......@@ -120,7 +120,8 @@ class TokenizerGroup:
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
......@@ -133,7 +134,8 @@ class TokenizerGroup:
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment