Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da1f7cc1
Unverified
Commit
da1f7cc1
authored
Jul 31, 2024
by
Cyrus Leung
Committed by
GitHub
Jul 31, 2024
Browse files
[mypy] Enable following imports for some directories (#6681)
parent
c32ab8be
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
185 additions
and
143 deletions
+185
-143
.github/workflows/mypy.yaml
.github/workflows/mypy.yaml
+13
-18
format.sh
format.sh
+13
-17
pyproject.toml
pyproject.toml
+16
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-1
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+49
-13
vllm/adapter_commons/models.py
vllm/adapter_commons/models.py
+15
-15
vllm/adapter_commons/request.py
vllm/adapter_commons/request.py
+5
-5
vllm/adapter_commons/worker_manager.py
vllm/adapter_commons/worker_manager.py
+7
-7
vllm/config.py
vllm/config.py
+2
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+6
-8
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-3
vllm/scripts.py
vllm/scripts.py
+5
-5
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+2
-0
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+4
-5
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+10
-8
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+15
-15
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+13
-15
vllm/utils.py
vllm/utils.py
+8
-5
No files found.
.github/workflows/mypy.yaml
View file @
da1f7cc1
...
@@ -32,22 +32,17 @@ jobs:
...
@@ -32,22 +32,17 @@ jobs:
pip install types-setuptools
pip install types-setuptools
-
name
:
Mypy
-
name
:
Mypy
run
:
|
run
:
|
mypy tests --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --follow-imports skip
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --follow-imports skip
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --follow-imports skip
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --follow-imports skip
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --follow-imports skip
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --follow-imports skip
mypy vllm/executor --config-file pyproject.toml
mypy vllm/lora --follow-imports skip
mypy vllm/inputs --config-file pyproject.toml
mypy vllm/model_executor --follow-imports skip
mypy vllm/logging --config-file pyproject.toml
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/lora --config-file pyproject.toml
mypy vllm/spec_decode --follow-imports skip
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/worker --follow-imports skip
mypy vllm/multimodal --config-file pyproject.toml
mypy
mypy vllm/platforms --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
format.sh
View file @
da1f7cc1
...
@@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'
...
@@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'
# Run mypy
# Run mypy
echo
'vLLM mypy:'
echo
'vLLM mypy:'
mypy tests
--config-file
pyproject.toml
mypy tests
--follow-imports
skip
mypy vllm/
*
.py
--config-file
pyproject.toml
mypy vllm/attention
--follow-imports
skip
mypy vllm/attention
--config-file
pyproject.toml
mypy vllm/core
--follow-imports
skip
mypy vllm/core
--config-file
pyproject.toml
mypy vllm/distributed
--follow-imports
skip
mypy vllm/distributed
--config-file
pyproject.toml
mypy vllm/engine
--follow-imports
skip
mypy vllm/engine
--config-file
pyproject.toml
mypy vllm/entrypoints
--follow-imports
skip
mypy vllm/entrypoints
--config-file
pyproject.toml
mypy vllm/executor
--follow-imports
skip
mypy vllm/executor
--config-file
pyproject.toml
mypy vllm/lora
--follow-imports
skip
mypy vllm/logging
--config-file
pyproject.toml
mypy vllm/model_executor
--follow-imports
skip
mypy vllm/lora
--config-file
pyproject.toml
mypy vllm/prompt_adapter
--follow-imports
skip
mypy vllm/model_executor
--config-file
pyproject.toml
mypy vllm/spec_decode
--follow-imports
skip
mypy vllm/multimodal
--config-file
pyproject.toml
mypy vllm/worker
--follow-imports
skip
mypy vllm/prompt_adapter
--config-file
pyproject.toml
mypy
mypy vllm/spec_decode
--config-file
pyproject.toml
mypy vllm/transformers_utils
--config-file
pyproject.toml
mypy vllm/usage
--config-file
pyproject.toml
mypy vllm/worker
--config-file
pyproject.toml
# If git diff returns a file that is in the skip list, the file may be checked anyway:
# If git diff returns a file that is in the skip list, the file may be checked anyway:
...
...
pyproject.toml
View file @
da1f7cc1
...
@@ -48,9 +48,23 @@ python_version = "3.8"
...
@@ -48,9 +48,23 @@ python_version = "3.8"
ignore_missing_imports
=
true
ignore_missing_imports
=
true
check_untyped_defs
=
true
check_untyped_defs
=
true
follow_imports
=
"s
kip
"
follow_imports
=
"s
ilent
"
files
=
"vllm"
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from format.sh and mypy.yaml
files
=
[
"vllm/*.py"
,
"vllm/adapter_commons"
,
"vllm/assets"
,
"vllm/inputs"
,
"vllm/logging"
,
"vllm/multimodal"
,
"vllm/platforms"
,
"vllm/server"
,
"vllm/transformers_utils"
,
"vllm/triton_utils"
,
"vllm/usage"
,
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude
=
[
exclude
=
[
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
,
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
,
...
...
vllm/_custom_ops.py
View file @
da1f7cc1
...
@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
]
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
...
...
vllm/_ipex_ops.py
View file @
da1f7cc1
...
@@ -25,27 +25,33 @@ class ipex_ops:
...
@@ -25,27 +25,33 @@ class ipex_ops:
x2
=
x2
.
reshape
(
num
,
d
)
x2
=
x2
.
reshape
(
num
,
d
)
return
x1
,
x2
return
x1
,
x2
@
staticmethod
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
silu_mul
(
x1
,
x2
,
out
)
ipex
.
llm
.
functional
.
silu_mul
(
x1
,
x2
,
out
)
@
staticmethod
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"none"
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"none"
)
@
staticmethod
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"tanh"
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"tanh"
)
@
staticmethod
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
@
staticmethod
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
# TODO add implementation of gelu_quick here
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
@
staticmethod
def
paged_attention_v1
(
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -78,12 +84,21 @@ class ipex_ops:
...
@@ -78,12 +84,21 @@ class ipex_ops:
).
view
(
num_kv_heads
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
# todo: ipex will refactor namespace
# todo: ipex will refactor namespace
torch
.
xpu
.
paged_attention_v1
(
out
,
query
.
contiguous
(),
torch
.
xpu
.
paged_attention_v1
(
# type: ignore
out
,
query
.
contiguous
(),
key_cache
.
view_as
(
value_cache
),
key_cache
.
view_as
(
value_cache
),
value_cache
,
head_mapping
,
scale
,
value_cache
,
block_tables
,
context_lens
,
block_size
,
head_mapping
,
max_context_len
,
alibi_slopes
)
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
)
@
staticmethod
def
paged_attention_v2
(
def
paged_attention_v2
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
...
@@ -119,13 +134,24 @@ class ipex_ops:
...
@@ -119,13 +134,24 @@ class ipex_ops:
).
view
(
num_kv_heads
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
# todo: ipex will refactor namespace
# todo: ipex will refactor namespace
torch
.
xpu
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
torch
.
xpu
.
paged_attention_v2
(
# type: ignore
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
.
contiguous
(),
query
.
contiguous
(),
key_cache
.
view_as
(
value_cache
),
key_cache
.
view_as
(
value_cache
),
value_cache
,
head_mapping
,
block_tables
,
value_cache
,
context_lens
,
scale
,
block_size
,
head_mapping
,
max_context_len
,
alibi_slopes
)
block_tables
,
context_lens
,
scale
,
block_size
,
max_context_len
,
alibi_slopes
,
)
@
staticmethod
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
# [batch_size, seq_len]
positions
:
torch
.
Tensor
,
# [batch_size, seq_len]
query
:
torch
.
Tensor
,
# [batch_size, seq_len, num_heads*head_size]
query
:
torch
.
Tensor
,
# [batch_size, seq_len, num_heads*head_size]
...
@@ -158,6 +184,7 @@ class ipex_ops:
...
@@ -158,6 +184,7 @@ class ipex_ops:
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
rotary_dim
,
is_neox
,
positions
)
rotary_dim
,
is_neox
,
positions
)
@
staticmethod
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
head_size
:
int
,
key
:
torch
.
Tensor
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
...
@@ -189,17 +216,20 @@ class ipex_ops:
...
@@ -189,17 +216,20 @@ class ipex_ops:
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
rotary_dim
,
is_neox
,
positions
)
rotary_dim
,
is_neox
,
positions
)
@
staticmethod
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
epsilon
:
float
)
->
None
:
tmp
=
ipex
.
llm
.
functional
.
rms_norm
(
input
,
weight
,
epsilon
)
tmp
=
ipex
.
llm
.
functional
.
rms_norm
(
input
,
weight
,
epsilon
)
out
.
copy_
(
tmp
)
out
.
copy_
(
tmp
)
@
staticmethod
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
tmp
=
ipex
.
llm
.
functional
.
add_rms_norm
(
residual
,
input
,
weight
,
None
,
tmp
=
ipex
.
llm
.
functional
.
add_rms_norm
(
residual
,
input
,
weight
,
None
,
epsilon
,
True
)
epsilon
,
True
)
input
.
copy_
(
tmp
)
input
.
copy_
(
tmp
)
@
staticmethod
def
varlen_attention
(
def
varlen_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -222,6 +252,7 @@ class ipex_ops:
...
@@ -222,6 +252,7 @@ class ipex_ops:
softmax_scale
,
zero_tensors
,
softmax_scale
,
zero_tensors
,
is_causal
,
return_softmax
,
gen_
)
is_causal
,
return_softmax
,
gen_
)
@
staticmethod
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -240,8 +271,13 @@ class ipex_ops:
...
@@ -240,8 +271,13 @@ class ipex_ops:
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
value_caches
:
List
[
torch
.
Tensor
],
value_caches
:
List
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
torch
.
xpu
.
copy_blocks
(
# type: ignore
key_caches
,
value_caches
,
block_mapping
,
)
@
staticmethod
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
swap_blocks
(
src
,
dst
,
block_mapping
)
torch
.
xpu
.
swap_blocks
(
src
,
dst
,
block_mapping
)
# type: ignore
vllm/adapter_commons/models.py
View file @
da1f7cc1
...
@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
...
@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
super
().
__init__
(
capacity
)
super
().
__init__
(
capacity
)
self
.
deactivate_fn
=
deactivate_fn
self
.
deactivate_fn
=
deactivate_fn
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
T
):
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
]
):
logger
.
debug
(
"Removing adapter int id: %d"
,
key
)
logger
.
debug
(
"Removing adapter int id: %d"
,
key
)
self
.
deactivate_fn
(
key
)
self
.
deactivate_fn
(
key
)
return
super
().
_on_remove
(
key
,
value
)
return
super
().
_on_remove
(
key
,
value
)
...
@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
...
@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
@
property
@
property
@
abstractmethod
@
abstractmethod
def
adapter_slots
(
self
):
def
adapter_slots
(
self
)
->
int
:
...
raise
NotImplementedError
@
property
@
property
@
abstractmethod
@
abstractmethod
def
capacity
(
self
):
def
capacity
(
self
)
->
int
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
activate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
activate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
add_adapter
(
self
,
adapter
:
Any
)
->
bool
:
def
add_adapter
(
self
,
adapter
:
Any
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
set_adapter_mapping
(
self
,
mapping
:
Any
)
->
None
:
def
set_adapter_mapping
(
self
,
mapping
:
Any
)
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_all_adapters
(
self
):
def
remove_all_adapters
(
self
)
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
get_adapter
(
self
,
adapter_id
:
int
)
->
Optional
[
Any
]:
def
get_adapter
(
self
,
adapter_id
:
int
)
->
Optional
[
Any
]:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
list_adapters
(
self
)
->
Dict
[
int
,
Any
]:
def
list_adapters
(
self
)
->
Dict
[
int
,
Any
]:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
pin_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
pin_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
vllm/adapter_commons/request.py
View file @
da1f7cc1
from
abc
import
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
@
dataclass
@
dataclass
class
AdapterRequest
:
class
AdapterRequest
(
ABC
)
:
"""
"""
Base class for adapter requests.
Base class for adapter requests.
"""
"""
@
property
@
property
@
abstractmethod
@
abstractmethod
def
adapter_id
(
self
):
def
adapter_id
(
self
)
->
int
:
...
raise
NotImplementedError
def
__post_init__
(
self
):
def
__post_init__
(
self
)
->
None
:
if
self
.
adapter_id
<
1
:
if
self
.
adapter_id
<
1
:
raise
ValueError
(
f
"id must be > 0, got
{
self
.
adapter_id
}
"
)
raise
ValueError
(
f
"id must be > 0, got
{
self
.
adapter_id
}
"
)
...
...
vllm/adapter_commons/worker_manager.py
View file @
da1f7cc1
...
@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
...
@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
@
property
@
property
@
abstractmethod
@
abstractmethod
def
is_enabled
(
self
)
->
bool
:
def
is_enabled
(
self
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
set_active_adapters
(
self
,
requests
:
Set
[
Any
],
def
set_active_adapters
(
self
,
requests
:
Set
[
Any
],
mapping
:
Optional
[
Any
])
->
None
:
mapping
:
Optional
[
Any
])
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
add_adapter
(
self
,
adapter_request
:
Any
)
->
bool
:
def
add_adapter
(
self
,
adapter_request
:
Any
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_all_adapters
(
self
):
def
remove_all_adapters
(
self
)
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
list_adapters
(
self
)
->
Set
[
int
]:
def
list_adapters
(
self
)
->
Set
[
int
]:
...
raise
NotImplementedError
vllm/config.py
View file @
da1f7cc1
...
@@ -724,7 +724,7 @@ class ParallelConfig:
...
@@ -724,7 +724,7 @@ class ParallelConfig:
backend
)
backend
)
self
.
_verify_args
()
self
.
_verify_args
()
self
.
rank
=
0
self
.
rank
:
int
=
0
@
property
@
property
def
use_ray
(
self
)
->
bool
:
def
use_ray
(
self
)
->
bool
:
...
@@ -850,6 +850,7 @@ class SchedulerConfig:
...
@@ -850,6 +850,7 @@ class SchedulerConfig:
class
DeviceConfig
:
class
DeviceConfig
:
device
:
Optional
[
torch
.
device
]
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
if
device
==
"auto"
:
if
device
==
"auto"
:
...
...
vllm/engine/llm_engine.py
View file @
da1f7cc1
...
@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
...
@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
TypeVar
,
Union
from
typing
import
Set
,
Type
,
TypeVar
,
Union
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -40,7 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
...
@@ -40,7 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
AnyTokenizer
,
BaseTokenizerGroup
,
get_tokenizer_group
)
get_tokenizer_group
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
...
@@ -478,12 +477,11 @@ class LLMEngine:
...
@@ -478,12 +477,11 @@ class LLMEngine:
def
get_tokenizer
(
def
get_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
def
get_tokenizer_for_seq
(
self
,
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
sequence
.
lora_request
)
sequence
.
lora_request
)
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
da1f7cc1
...
@@ -5,7 +5,6 @@ from http import HTTPStatus
...
@@ -5,7 +5,6 @@ from http import HTTPStatus
from
typing
import
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
from
pydantic
import
Field
from
pydantic
import
Field
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -30,6 +29,7 @@ from vllm.pooling_params import PoolingParams
...
@@ -30,6 +29,7 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer_group
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -49,8 +49,6 @@ class LoRAModulePath:
...
@@ -49,8 +49,6 @@ class LoRAModulePath:
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
]
EmbeddingRequest
,
TokenizeRequest
]
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
class
TextTokensPrompt
(
TypedDict
):
class
TextTokensPrompt
(
TypedDict
):
prompt
:
str
prompt
:
str
...
...
vllm/scripts.py
View file @
da1f7cc1
...
@@ -4,9 +4,10 @@ import asyncio
...
@@ -4,9 +4,10 @@ import asyncio
import
os
import
os
import
signal
import
signal
import
sys
import
sys
from
typing
import
Optional
from
typing
import
List
,
Optional
from
openai
import
OpenAI
from
openai
import
OpenAI
from
openai.types.chat
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
...
@@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
...
@@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
def
chat
(
system_prompt
:
Optional
[
str
],
model_name
:
str
,
def
chat
(
system_prompt
:
Optional
[
str
],
model_name
:
str
,
client
:
OpenAI
)
->
None
:
client
:
OpenAI
)
->
None
:
conversation
=
[]
conversation
:
List
[
ChatCompletionMessageParam
]
=
[]
if
system_prompt
is
not
None
:
if
system_prompt
is
not
None
:
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
print
(
"Please enter a message for the chat model:"
)
print
(
"Please enter a message for the chat model:"
)
while
True
:
while
True
:
input_message
=
input
(
"> "
)
input_message
=
input
(
"> "
)
message
=
{
"role"
:
"user"
,
"content"
:
input_message
}
conversation
.
append
({
"role"
:
"user"
,
"content"
:
input_message
})
conversation
.
append
(
message
)
chat_completion
=
client
.
chat
.
completions
.
create
(
model
=
model_name
,
chat_completion
=
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
conversation
)
messages
=
conversation
)
...
@@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
...
@@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
response_message
=
chat_completion
.
choices
[
0
].
message
response_message
=
chat_completion
.
choices
[
0
].
message
output
=
response_message
.
content
output
=
response_message
.
content
conversation
.
append
(
response_message
)
conversation
.
append
(
response_message
)
# type: ignore
print
(
output
)
print
(
output
)
...
...
vllm/transformers_utils/detokenizer.py
View file @
da1f7cc1
...
@@ -37,6 +37,8 @@ class Detokenizer:
...
@@ -37,6 +37,8 @@ class Detokenizer:
The prompt logprobs with the decoded tokens.
The prompt logprobs with the decoded tokens.
"""
"""
prms
=
seq_group
.
sampling_params
prms
=
seq_group
.
sampling_params
assert
prms
is
not
None
# We can pick any sequence for the prompt.
# We can pick any sequence for the prompt.
seq
=
next
(
iter
(
seq_group
.
seqs_dict
.
values
()))
seq
=
next
(
iter
(
seq_group
.
seqs_dict
.
values
()))
# Only prompt, without the generated token.
# Only prompt, without the generated token.
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
da1f7cc1
...
@@ -2,10 +2,9 @@ from typing import Optional, Type
...
@@ -2,10 +2,9 @@ from typing import Optional, Type
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.executor.ray_utils
import
ray
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
.base_tokenizer_group
import
AnyTokenizer
,
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
from
.tokenizer_group
import
TokenizerGroup
TokenizerGroup
)
if
ray
:
if
ray
:
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
...
@@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
...
@@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
return
tokenizer_cls
.
from_config
(
tokenizer_pool_config
,
**
init_kwargs
)
return
tokenizer_cls
.
from_config
(
tokenizer_pool_config
,
**
init_kwargs
)
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
__all__
=
[
"AnyTokenizer"
,
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
da1f7cc1
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Union
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
class
BaseTokenizerGroup
(
ABC
):
class
BaseTokenizerGroup
(
ABC
):
"""A group of tokenizers that can be used for LoRA adapters."""
"""A group of tokenizers that can be used for LoRA adapters."""
...
@@ -48,16 +50,16 @@ class BaseTokenizerGroup(ABC):
...
@@ -48,16 +50,16 @@ class BaseTokenizerGroup(ABC):
@
abstractmethod
@
abstractmethod
def
get_lora_tokenizer
(
def
get_lora_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
"""Get a tokenizer for a LoRA request."""
"""Get a tokenizer for a LoRA request."""
pass
pass
@
abstractmethod
@
abstractmethod
async
def
get_lora_tokenizer_async
(
async
def
get_lora_tokenizer_async
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
"""Get a tokenizer for a LoRA request."""
"""Get a tokenizer for a LoRA request."""
pass
pass
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
da1f7cc1
...
@@ -6,18 +6,16 @@ try:
...
@@ -6,18 +6,16 @@ try:
from
ray.exceptions
import
ActorDiedError
from
ray.exceptions
import
ActorDiedError
except
ImportError
:
except
ImportError
:
# For older versions of Ray
# For older versions of Ray
from
ray.exceptions
import
RayActorError
as
ActorDiedError
from
ray.exceptions
import
RayActorError
as
ActorDiedError
# type: ignore
from
ray.util.scheduling_strategies
import
NodeAffinitySchedulingStrategy
from
ray.util.scheduling_strategies
import
NodeAffinitySchedulingStrategy
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.executor.ray_utils
import
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
.base_tokenizer_group
import
AnyTokenizer
,
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
from
.tokenizer_group
import
TokenizerGroup
TokenizerGroup
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
**
self
.
_tokenizer_config
,
)
**
self
.
_tokenizer_config
,
)
self
.
_ray_tokenizer_group_cls
=
ray
.
remote
(
self
.
_ray_tokenizer_group_cls
=
ray
.
remote
(
self
.
_worker_cls
).
options
(
**
ray_actor_options
)
self
.
_worker_cls
).
options
(
**
ray_actor_options
)
# type: ignore
self
.
tokenizer_actors
=
[
self
.
_init_actor
()
for
_
in
range
(
num_actors
)]
self
.
tokenizer_actors
=
[
self
.
_init_actor
()
for
_
in
range
(
num_actors
)]
self
.
_idle_actors
:
Optional
[
asyncio
.
Queue
]
=
None
self
.
_idle_actors
:
Optional
[
asyncio
.
Queue
]
=
None
...
@@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return
len
(
self
.
tokenizer_actors
)
return
len
(
self
.
tokenizer_actors
)
def
ping
(
self
):
def
ping
(
self
):
return
ray
.
get
(
return
ray
.
get
([
[
actor
.
ping
.
remote
()
for
actor
in
self
.
tokenizer_actors
])
actor
.
ping
.
remote
()
# type: ignore
for
actor
in
self
.
tokenizer_actors
])
def
_ensure_queue_initialized
(
self
):
def
_ensure_queue_initialized
(
self
):
if
self
.
_idle_actors
is
None
:
if
self
.
_idle_actors
is
None
:
...
@@ -209,14 +209,14 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -209,14 +209,14 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def
get_lora_tokenizer
(
def
get_lora_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
return
self
.
_local_tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
return
self
.
_local_tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
async
def
get_lora_tokenizer_async
(
async
def
get_lora_tokenizer_async
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
return
await
self
.
_local_tokenizer_group
.
get_lora_tokenizer_async
(
return
await
self
.
_local_tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
lora_request
)
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
da1f7cc1
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
get_lora_tokenizer_async
,
get_lora_tokenizer_async
,
get_tokenizer
)
get_tokenizer
)
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
vllm.utils
import
LRUCache
from
vllm.utils
import
LRUCache
from
.base_tokenizer_group
import
AnyTokenizer
,
BaseTokenizerGroup
class
TokenizerGroup
(
BaseTokenizerGroup
):
class
TokenizerGroup
(
BaseTokenizerGroup
):
"""A group of tokenizers that can be used for LoRA adapters."""
"""A group of tokenizers that can be used for LoRA adapters."""
...
@@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
enable_lora
=
enable_lora
self
.
enable_lora
=
enable_lora
self
.
max_input_length
=
max_input_length
self
.
max_input_length
=
max_input_length
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
self
.
lora_tokenizers
=
LRUCache
[
PreTrained
Tokenizer
](
self
.
lora_tokenizers
=
LRUCache
[
Any
Tokenizer
](
capacity
=
max_num_seqs
)
if
enable_lora
else
None
capacity
=
max_num_seqs
if
enable_lora
else
0
)
@
classmethod
@
classmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
...
@@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
return
self
.
max_input_length
return
self
.
max_input_length
def
_raise_if_input_too_long
(
self
,
def
_raise_if_input_too_long
(
self
,
encoded_tokens
:
List
[
str
],
encoded_tokens
:
List
[
int
],
lora_request
:
Optional
[
LoRARequest
]
=
None
):
lora_request
:
Optional
[
LoRARequest
]
=
None
):
input_length
=
len
(
encoded_tokens
)
input_length
=
len
(
encoded_tokens
)
if
lora_request
:
if
lora_request
:
...
@@ -73,8 +71,8 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -73,8 +71,8 @@ class TokenizerGroup(BaseTokenizerGroup):
def
get_lora_tokenizer
(
def
get_lora_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
if
not
lora_request
or
not
self
.
enable_lora
:
if
not
lora_request
or
not
self
.
enable_lora
:
return
self
.
tokenizer
return
self
.
tokenizer
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
...
@@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
lora_tokenizers
.
put
(
lora_request
.
lora_int_id
,
tokenizer
)
self
.
lora_tokenizers
.
put
(
lora_request
.
lora_int_id
,
tokenizer
)
return
tokenizer
return
tokenizer
else
:
else
:
return
self
.
lora_tokenizers
.
get
(
lora_request
.
lora_int_id
)
return
self
.
lora_tokenizers
[
lora_request
.
lora_int_id
]
async
def
get_lora_tokenizer_async
(
async
def
get_lora_tokenizer_async
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
if
not
lora_request
or
not
self
.
enable_lora
:
if
not
lora_request
or
not
self
.
enable_lora
:
return
self
.
tokenizer
return
self
.
tokenizer
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
...
@@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
lora_tokenizers
.
put
(
lora_request
.
lora_int_id
,
tokenizer
)
self
.
lora_tokenizers
.
put
(
lora_request
.
lora_int_id
,
tokenizer
)
return
tokenizer
return
tokenizer
else
:
else
:
return
self
.
lora_tokenizers
.
get
(
lora_request
.
lora_int_id
)
return
self
.
lora_tokenizers
[
lora_request
.
lora_int_id
]
vllm/utils.py
View file @
da1f7cc1
...
@@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
...
@@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
len
(
self
.
cache
)
return
len
(
self
.
cache
)
def
__getitem__
(
self
,
key
:
Hashable
)
->
Optional
[
T
]:
def
__getitem__
(
self
,
key
:
Hashable
)
->
T
:
return
self
.
get
(
key
)
value
=
self
.
cache
[
key
]
# Raise KeyError if not exists
self
.
cache
.
move_to_end
(
key
)
return
value
def
__setitem__
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
def
__setitem__
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
self
.
put
(
key
,
value
)
self
.
put
(
key
,
value
)
...
@@ -109,8 +111,9 @@ class LRUCache(Generic[T]):
...
@@ -109,8 +111,9 @@ class LRUCache(Generic[T]):
def
get
(
self
,
def
get
(
self
,
key
:
Hashable
,
key
:
Hashable
,
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
value
:
Optional
[
T
]
if
key
in
self
.
cache
:
if
key
in
self
.
cache
:
value
:
Optional
[
T
]
=
self
.
cache
[
key
]
value
=
self
.
cache
[
key
]
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
else
:
else
:
value
=
default_value
value
=
default_value
...
@@ -590,8 +593,8 @@ class CudaMemoryProfiler:
...
@@ -590,8 +593,8 @@ class CudaMemoryProfiler:
torch
.
cuda
.
reset_peak_memory_stats
(
self
.
device
)
torch
.
cuda
.
reset_peak_memory_stats
(
self
.
device
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
self
.
device
)
mem
=
torch
.
cuda
.
max_memory_allocated
(
self
.
device
)
elif
is_xpu
():
elif
is_xpu
():
torch
.
xpu
.
reset_peak_memory_stats
(
self
.
device
)
torch
.
xpu
.
reset_peak_memory_stats
(
self
.
device
)
# type: ignore
mem
=
torch
.
xpu
.
max_memory_allocated
(
self
.
device
)
mem
=
torch
.
xpu
.
max_memory_allocated
(
self
.
device
)
# type: ignore
return
mem
return
mem
def
__enter__
(
self
):
def
__enter__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment