Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e8744a5
"vscode:/vscode.git/clone" did not exist on "7bdb03ea31105a087cff1d7db0431a7f49fe4f57"
Unverified
Commit
9e8744a5
authored
Mar 11, 2024
by
Roy
Committed by
GitHub
Mar 10, 2024
Browse files
[BugFix] Fix get tokenizer when using ray (#3301)
parent
e4a28e53
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
23 additions
and
7 deletions
+23
-7
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+3
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+7
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+1
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-1
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+4
-2
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
9e8744a5
...
@@ -89,3 +89,6 @@ async def test_new_requests_event():
...
@@ -89,3 +89,6 @@ async def test_new_requests_event():
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
True
,
engine_use_ray
=
True
)
assert
engine
.
get_tokenizer
()
is
not
None
vllm/engine/async_llm_engine.py
View file @
9e8744a5
...
@@ -5,6 +5,8 @@ from functools import partial
...
@@ -5,6 +5,8 @@ from functools import partial
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
AsyncIterator
,
Callable
)
Union
,
AsyncIterator
,
Callable
)
from
transformers
import
PreTrainedTokenizer
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
@@ -372,8 +374,11 @@ class AsyncLLMEngine:
...
@@ -372,8 +374,11 @@ class AsyncLLMEngine:
self
.
set_errored
(
exc
)
self
.
set_errored
(
exc
)
self
.
_request_tracker
.
propagate_exception
(
exc
)
self
.
_request_tracker
.
propagate_exception
(
exc
)
def
get_tokenizer
(
self
):
async
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
engine
.
tokenizer
.
tokenizer
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_tokenizer
.
remote
()
else
:
return
self
.
engine
.
get_tokenizer
()
def
start_background_loop
(
self
)
->
None
:
def
start_background_loop
(
self
)
->
None
:
"""Start the background loop."""
"""Start the background loop."""
...
...
vllm/engine/llm_engine.py
View file @
9e8744a5
...
@@ -7,6 +7,8 @@ import importlib
...
@@ -7,6 +7,8 @@ import importlib
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
)
Union
)
from
transformers
import
PreTrainedTokenizer
import
vllm
import
vllm
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
...
@@ -163,7 +165,11 @@ class LLMEngine:
...
@@ -163,7 +165,11 @@ class LLMEngine:
# the closure used to initialize Ray worker actors
# the closure used to initialize Ray worker actors
raise
RuntimeError
(
"LLMEngine should not be pickled!"
)
raise
RuntimeError
(
"LLMEngine should not be pickled!"
)
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
):
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
()
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
def
_dispatch_worker
(
self
):
def
_dispatch_worker
(
self
):
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
9e8744a5
...
@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
get_guided_decoding_logits_processor
(
await
get_guided_decoding_logits_processor
(
request
,
self
.
engine
.
get_tokenizer
()))
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logits_processor
:
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
=
[]
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
9e8744a5
...
@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
guided_decode_logit_processor
=
(
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
await
get_guided_decoding_logits_processor
(
request
,
self
.
engine
.
get_tokenizer
()))
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logit_processor
is
not
None
:
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
=
[]
...
...
vllm/transformers_utils/tokenizer.py
View file @
9e8744a5
...
@@ -120,7 +120,8 @@ class TokenizerGroup:
...
@@ -120,7 +120,8 @@ class TokenizerGroup:
def
get_lora_tokenizer
(
def
get_lora_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
])
->
"PreTrainedTokenizer"
:
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
"PreTrainedTokenizer"
:
if
not
lora_request
or
not
self
.
enable_lora
:
if
not
lora_request
or
not
self
.
enable_lora
:
return
self
.
tokenizer
return
self
.
tokenizer
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
...
@@ -133,7 +134,8 @@ class TokenizerGroup:
...
@@ -133,7 +134,8 @@ class TokenizerGroup:
async
def
get_lora_tokenizer_async
(
async
def
get_lora_tokenizer_async
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
])
->
"PreTrainedTokenizer"
:
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
"PreTrainedTokenizer"
:
if
not
lora_request
or
not
self
.
enable_lora
:
if
not
lora_request
or
not
self
.
enable_lora
:
return
self
.
tokenizer
return
self
.
tokenizer
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment