"vscode:/vscode.git/clone" did not exist on "7bdb03ea31105a087cff1d7db0431a7f49fe4f57"
Unverified commit 9e8744a5, authored by Roy and committed by GitHub
Browse files

[BugFix] Fix get tokenizer when using ray (#3301)

parent e4a28e53
...@@ -89,3 +89,6 @@ async def test_new_requests_event(): ...@@ -89,3 +89,6 @@ async def test_new_requests_event():
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3 assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1 assert engine.engine.step_calls == old_step_calls + 1
engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
assert engine.get_tokenizer() is not None
...@@ -5,6 +5,8 @@ from functools import partial ...@@ -5,6 +5,8 @@ from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator, Callable) Union, AsyncIterator, Callable)
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -372,8 +374,11 @@ class AsyncLLMEngine: ...@@ -372,8 +374,11 @@ class AsyncLLMEngine:
self.set_errored(exc) self.set_errored(exc)
self._request_tracker.propagate_exception(exc) self._request_tracker.propagate_exception(exc)
def get_tokenizer(self): async def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.engine.tokenizer.tokenizer if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
else:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
......
...@@ -7,6 +7,8 @@ import importlib ...@@ -7,6 +7,8 @@ import importlib
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union) Union)
from transformers import PreTrainedTokenizer
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
...@@ -163,7 +165,11 @@ class LLMEngine: ...@@ -163,7 +165,11 @@ class LLMEngine:
# the closure used to initialize Ray worker actors # the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!") raise RuntimeError("LLMEngine should not be pickled!")
def get_tokenizer_for_seq(self, sequence: Sequence): def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer()
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request) return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _dispatch_worker(self): def _dispatch_worker(self):
......
...@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer())) request, await self.engine.get_tokenizer()))
if guided_decode_logits_processor: if guided_decode_logits_processor:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
......
...@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
guided_decode_logit_processor = ( guided_decode_logit_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer())) request, await self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None: if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
......
...@@ -120,7 +120,8 @@ class TokenizerGroup: ...@@ -120,7 +120,8 @@ class TokenizerGroup:
def get_lora_tokenizer( def get_lora_tokenizer(
self, self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora: if not lora_request or not self.enable_lora:
return self.tokenizer return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers: if lora_request.lora_int_id not in self.lora_tokenizers:
...@@ -133,7 +134,8 @@ class TokenizerGroup: ...@@ -133,7 +134,8 @@ class TokenizerGroup:
async def get_lora_tokenizer_async( async def get_lora_tokenizer_async(
self, self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora: if not lora_request or not self.enable_lora:
return self.tokenizer return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers: if lora_request.lora_int_id not in self.lora_tokenizers:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment