Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
01d079fd
Unverified
Commit
01d079fd
authored
Dec 04, 2024
by
Xin Yang
Committed by
GitHub
Dec 04, 2024
Browse files
[LoRA] Change lora_tokenizers capacity (#10796)
Signed-off-by:
Xin Yang
<
xyang19@gmail.com
>
parent
c92acb96
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
31 additions
and
10 deletions
+31
-10
tests/lora/test_tokenizer_group.py
tests/lora/test_tokenizer_group.py
+20
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+1
-1
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+1
-2
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+5
-4
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+2
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+1
-1
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-1
No files found.
tests/lora/test_tokenizer_group.py
View file @
01d079fd
...
...
@@ -17,6 +17,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
tokenizer_id
=
"gpt2"
,
enable_lora
=
True
,
max_num_seqs
=
1
,
max_loras
=
1
,
max_input_length
=
None
,
)
lora_request
=
LoRARequest
(
"1"
,
1
,
sql_lora_files
)
...
...
@@ -53,3 +54,22 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
lora_request
=
LoRARequest
(
"1"
,
1
,
str
(
tmp_path
))
tokenizer
=
get_lora_tokenizer
(
lora_request
)
assert
not
tokenizer
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"max_loras"
,
[
1
,
2
])
def
test_lora_tokenizers
(
enable_lora
,
max_num_seqs
,
max_loras
):
tokenizer_group
=
get_tokenizer_group
(
get_tokenizer_pool_config
(
None
),
tokenizer_id
=
"gpt2"
,
enable_lora
=
enable_lora
,
max_num_seqs
=
max_num_seqs
,
max_loras
=
max_loras
,
max_input_length
=
None
,
)
if
enable_lora
:
assert
tokenizer_group
.
lora_tokenizers
.
capacity
==
max
(
max_num_seqs
,
max_loras
)
else
:
assert
tokenizer_group
.
lora_tokenizers
.
capacity
==
0
vllm/engine/llm_engine.py
View file @
01d079fd
...
...
@@ -620,7 +620,7 @@ class LLMEngine:
model_config
=
self
.
model_config
,
scheduler_config
=
self
.
scheduler_config
,
parallel_config
=
self
.
parallel_config
,
enable_lora
=
bool
(
self
.
lora_config
)
)
lora_config
=
self
.
lora_config
)
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
...
vllm/engine/multiprocessing/client.py
View file @
01d079fd
...
...
@@ -94,8 +94,7 @@ class MQLLMEngineClient(EngineClient):
model_config
=
self
.
model_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
parallel_config
=
engine_config
.
parallel_config
,
enable_lora
=
bool
(
engine_config
.
lora_config
),
)
lora_config
=
engine_config
.
lora_config
)
self
.
input_preprocessor
=
InputPreprocessor
(
self
.
model_config
,
self
.
tokenizer
)
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
01d079fd
from
typing
import
Optional
,
Type
from
vllm.config
import
(
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
TokenizerPoolConfig
)
from
vllm.config
import
(
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
TokenizerPoolConfig
)
from
vllm.executor.ray_utils
import
ray
from
.base_tokenizer_group
import
AnyTokenizer
,
BaseTokenizerGroup
...
...
@@ -16,10 +16,11 @@ else:
def
init_tokenizer_from_configs
(
model_config
:
ModelConfig
,
scheduler_config
:
SchedulerConfig
,
parallel_config
:
ParallelConfig
,
enable_lora
:
bool
):
lora_config
:
LoRAConfig
):
init_kwargs
=
dict
(
tokenizer_id
=
model_config
.
tokenizer
,
enable_lora
=
enable_lora
,
enable_lora
=
bool
(
lora_config
)
,
max_num_seqs
=
scheduler_config
.
max_num_seqs
,
max_loras
=
lora_config
.
max_loras
if
lora_config
else
0
,
max_input_length
=
None
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
trust_remote_code
=
model_config
.
trust_remote_code
,
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
01d079fd
...
...
@@ -21,8 +21,9 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
enable_lora
=
enable_lora
self
.
max_input_length
=
max_input_length
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
max_loras
=
tokenizer_config
.
get
(
"max_loras"
,
0
)
self
.
lora_tokenizers
=
LRUCache
[
AnyTokenizer
](
capacity
=
max_num_seqs
if
enable_lora
else
0
)
capacity
=
max
(
max_loras
,
max_num_seqs
)
if
enable_lora
else
0
)
@
classmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
...
...
vllm/v1/engine/async_llm.py
View file @
01d079fd
...
...
@@ -51,7 +51,7 @@ class AsyncLLM(EngineClient):
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
parallel_config
=
vllm_config
.
parallel_config
,
enable_lora
=
bool
(
vllm_config
.
lora_config
)
)
lora_config
=
vllm_config
.
lora_config
)
self
.
tokenizer
.
ping
()
# Request streams (map of request_id -> AsyncStream).
...
...
vllm/v1/engine/llm_engine.py
View file @
01d079fd
...
...
@@ -46,7 +46,7 @@ class LLMEngine:
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
parallel_config
=
vllm_config
.
parallel_config
,
enable_lora
=
bool
(
vllm_config
.
lora_config
)
)
lora_config
=
vllm_config
.
lora_config
)
self
.
tokenizer
.
ping
()
# Processor (convert Inputs --> EngineCoreRequests)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment