Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ba85bc1
Unverified
Commit
9ba85bc1
authored
Aug 13, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 13, 2024
Browse files
[mypy] Misc. typing improvements (#7417)
parent
198d6a28
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
74 additions
and
75 deletions
+74
-75
tests/tensorizer_loader/conftest.py
tests/tensorizer_loader/conftest.py
+12
-4
tests/test_utils.py
tests/test_utils.py
+8
-22
tests/utils.py
tests/utils.py
+8
-3
vllm/inputs/registry.py
vllm/inputs/registry.py
+4
-4
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-3
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+3
-3
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+2
-2
vllm/multimodal/image.py
vllm/multimodal/image.py
+2
-3
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+7
-3
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+8
-10
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+3
-3
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+1
-2
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+6
-8
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+2
-1
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+3
-2
vllm/utils.py
vllm/utils.py
+2
-2
No files found.
tests/tensorizer_loader/conftest.py
View file @
9ba85bc1
import
contextlib
import
contextlib
import
functools
import
functools
import
gc
import
gc
from
typing
import
Callable
,
TypeVar
import
pytest
import
pytest
import
ray
import
ray
import
torch
import
torch
from
typing_extensions
import
ParamSpec
from
vllm.distributed
import
(
destroy_distributed_environment
,
from
vllm.distributed
import
(
destroy_distributed_environment
,
destroy_model_parallel
)
destroy_model_parallel
)
...
@@ -22,12 +24,16 @@ def cleanup():
...
@@ -22,12 +24,16 @@ def cleanup():
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
retry_until_skip
(
n
):
_P
=
ParamSpec
(
"_P"
)
_R
=
TypeVar
(
"_R"
)
def
decorator_retry
(
func
):
def
retry_until_skip
(
n
:
int
):
def
decorator_retry
(
func
:
Callable
[
_P
,
_R
])
->
Callable
[
_P
,
_R
]:
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
def
wrapper_retry
(
*
args
,
**
kwargs
)
:
def
wrapper_retry
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_R
:
for
i
in
range
(
n
):
for
i
in
range
(
n
):
try
:
try
:
return
func
(
*
args
,
**
kwargs
)
return
func
(
*
args
,
**
kwargs
)
...
@@ -35,7 +41,9 @@ def retry_until_skip(n):
...
@@ -35,7 +41,9 @@ def retry_until_skip(n):
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
if
i
==
n
-
1
:
if
i
==
n
-
1
:
pytest
.
skip
(
"Skipping test after attempts.."
)
pytest
.
skip
(
f
"Skipping test after
{
n
}
attempts."
)
raise
AssertionError
(
"Code should not be reached"
)
return
wrapper_retry
return
wrapper_retry
...
...
tests/test_utils.py
View file @
9ba85bc1
import
asyncio
import
asyncio
import
os
import
os
import
socket
import
socket
import
sys
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncIterator
,
Awaitable
,
Protocol
,
from
typing
import
AsyncIterator
,
Tuple
Tuple
,
TypeVar
)
import
pytest
import
pytest
...
@@ -13,26 +11,11 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
...
@@ -13,26 +11,11 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
from
.utils
import
error_on_warning
from
.utils
import
error_on_warning
if
sys
.
version_info
<
(
3
,
10
):
if
TYPE_CHECKING
:
_AwaitableT
=
TypeVar
(
"_AwaitableT"
,
bound
=
Awaitable
[
Any
])
_AwaitableT_co
=
TypeVar
(
"_AwaitableT_co"
,
bound
=
Awaitable
[
Any
],
covariant
=
True
)
class
_SupportsSynchronousAnext
(
Protocol
[
_AwaitableT_co
]):
def
__anext__
(
self
)
->
_AwaitableT_co
:
...
def
anext
(
i
:
"_SupportsSynchronousAnext[_AwaitableT]"
,
/
)
->
"_AwaitableT"
:
return
i
.
__anext__
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_merge_async_iterators
():
async
def
test_merge_async_iterators
():
async
def
mock_async_iterator
(
idx
:
int
)
->
AsyncIterator
[
str
]
:
async
def
mock_async_iterator
(
idx
:
int
):
try
:
try
:
while
True
:
while
True
:
yield
f
"item from iterator
{
idx
}
"
yield
f
"item from iterator
{
idx
}
"
...
@@ -41,8 +24,10 @@ async def test_merge_async_iterators():
...
@@ -41,8 +24,10 @@ async def test_merge_async_iterators():
print
(
f
"iterator
{
idx
}
cancelled"
)
print
(
f
"iterator
{
idx
}
cancelled"
)
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
merged_iterator
:
AsyncIterator
[
Tuple
[
int
,
str
]]
=
merge_async_iterators
(
merged_iterator
=
merge_async_iterators
(
*
iterators
,
*
iterators
,
is_cancelled
=
partial
(
asyncio
.
sleep
,
0
,
result
=
False
))
is_cancelled
=
partial
(
asyncio
.
sleep
,
0
,
result
=
False
))
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
for
idx
,
output
in
generator
:
async
for
idx
,
output
in
generator
:
...
@@ -56,7 +41,8 @@ async def test_merge_async_iterators():
...
@@ -56,7 +41,8 @@ async def test_merge_async_iterators():
for
iterator
in
iterators
:
for
iterator
in
iterators
:
try
:
try
:
await
asyncio
.
wait_for
(
anext
(
iterator
),
1
)
# Can use anext() in python >= 3.10
await
asyncio
.
wait_for
(
iterator
.
__anext__
(),
1
)
except
StopAsyncIteration
:
except
StopAsyncIteration
:
# All iterators should be cancelled and print this message.
# All iterators should be cancelled and print this message.
print
(
"Iterator was cancelled normally"
)
print
(
"Iterator was cancelled normally"
)
...
...
tests/utils.py
View file @
9ba85bc1
...
@@ -7,12 +7,13 @@ import time
...
@@ -7,12 +7,13 @@ import time
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
openai
import
openai
import
ray
import
ray
import
requests
import
requests
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
typing_extensions
import
ParamSpec
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
...
@@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
...
@@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
time
.
sleep
(
5
)
time
.
sleep
(
5
)
def
fork_new_process_for_each_test
(
f
):
_P
=
ParamSpec
(
"_P"
)
def
fork_new_process_for_each_test
(
f
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
"""Decorator to fork a new process for each test function.
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
"""
@
functools
.
wraps
(
f
)
@
functools
.
wraps
(
f
)
def
wrapper
(
*
args
,
**
kwargs
)
:
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
None
:
# Make the process the leader of its own process group
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
# to avoid sending SIGTERM to the parent process
os
.
setpgrp
()
os
.
setpgrp
()
...
...
vllm/inputs/registry.py
View file @
9ba85bc1
import
functools
import
functools
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
Optional
,
Tuple
,
Type
,
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
,
Optional
,
Tuple
,
Type
TypeVar
)
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
...
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -44,7 +44,7 @@ class InputContext:
...
@@ -44,7 +44,7 @@ class InputContext:
return
multimodal_config
return
multimodal_config
def
get_hf_config
(
self
,
hf_config_type
:
Type
[
C
])
->
C
:
def
get_hf_config
(
self
,
hf_config_type
:
Type
[
C
]
=
PretrainedConfig
)
->
C
:
"""
"""
Get the HuggingFace configuration
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
(:class:`transformers.PretrainedConfig`) of the model,
...
...
vllm/model_executor/models/internvl.py
View file @
9ba85bc1
...
@@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
...
@@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
use_thumbnail
=
hf_config
.
use_thumbnail
use_thumbnail
=
hf_config
.
use_thumbnail
...
@@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
return
llm_inputs
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
image_size
=
vision_config
.
image_size
...
@@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
...
@@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
trust_remote_code
=
True
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
9ba85bc1
...
@@ -34,7 +34,7 @@ import torch.types
...
@@ -34,7 +34,7 @@ import torch.types
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
torch.nn.init
import
trunc_normal_
from
transformers
.configuration_utils
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
...
@@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
...
@@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def
get_max_minicpmv_image_tokens
(
ctx
:
InputContext
):
def
get_max_minicpmv_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
return
getattr
(
hf_config
,
"query_num"
,
64
)
return
getattr
(
hf_config
,
"query_num"
,
64
)
...
@@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
...
@@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
...
...
vllm/model_executor/models/phi3v.py
View file @
9ba85bc1
...
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
...
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
):
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
):
return
get_phi3v_image_feature_size
(
return
get_phi3v_image_feature_size
(
ctx
.
get_hf_config
(
PretrainedConfig
),
ctx
.
get_hf_config
(),
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
)
)
...
@@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
return
llm_inputs
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
image_data
=
multi_modal_data
[
"image"
]
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
if
isinstance
(
image_data
,
Image
.
Image
):
...
...
vllm/multimodal/image.py
View file @
9ba85bc1
...
@@ -3,13 +3,12 @@ from typing import List, Optional, Tuple, TypeVar
...
@@ -3,13 +3,12 @@ from typing import List, Optional, Tuple, TypeVar
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
PreTrainedTokenizerBase
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.inputs.registry
import
InputContext
from
vllm.inputs.registry
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.image_processor
import
get_image_processor
from
vllm.transformers_utils.image_processor
import
get_image_processor
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.base
import
MultiModalInputs
,
MultiModalPlugin
from
.base
import
MultiModalInputs
,
MultiModalPlugin
...
@@ -40,7 +39,7 @@ def repeat_and_pad_token(
...
@@ -40,7 +39,7 @@ def repeat_and_pad_token(
def
repeat_and_pad_image_tokens
(
def
repeat_and_pad_image_tokens
(
tokenizer
:
PreTrained
Tokenizer
Base
,
tokenizer
:
Any
Tokenizer
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
*
,
*
,
...
...
vllm/platforms/cuda.py
View file @
9ba85bc1
...
@@ -4,9 +4,10 @@ pynvml. However, it should not initialize cuda context.
...
@@ -4,9 +4,10 @@ pynvml. However, it should not initialize cuda context.
import
os
import
os
from
functools
import
lru_cache
,
wraps
from
functools
import
lru_cache
,
wraps
from
typing
import
List
,
Tuple
from
typing
import
Callable
,
List
,
Tuple
,
TypeVar
import
pynvml
import
pynvml
from
typing_extensions
import
ParamSpec
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -14,16 +15,19 @@ from .interface import Platform, PlatformEnum
...
@@ -14,16 +15,19 @@ from .interface import Platform, PlatformEnum
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_P
=
ParamSpec
(
"_P"
)
_R
=
TypeVar
(
"_R"
)
# NVML utils
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
# the major benefit of using NVML is that it will not initialize CUDA
def
with_nvml_context
(
fn
)
:
def
with_nvml_context
(
fn
:
Callable
[
_P
,
_R
])
->
Callable
[
_P
,
_R
]
:
@
wraps
(
fn
)
@
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
)
:
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_R
:
pynvml
.
nvmlInit
()
pynvml
.
nvmlInit
()
try
:
try
:
return
fn
(
*
args
,
**
kwargs
)
return
fn
(
*
args
,
**
kwargs
)
...
...
vllm/transformers_utils/detokenizer.py
View file @
9ba85bc1
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
.tokenizer
import
AnyTokenizer
from
.tokenizer_group
import
BaseTokenizerGroup
# Used eg. for marking rejected tokens in spec decoding.
# Used eg. for marking rejected tokens in spec decoding.
INVALID_TOKEN_ID
=
-
1
INVALID_TOKEN_ID
=
-
1
...
@@ -16,8 +15,7 @@ class Detokenizer:
...
@@ -16,8 +15,7 @@ class Detokenizer:
def
__init__
(
self
,
tokenizer_group
:
BaseTokenizerGroup
):
def
__init__
(
self
,
tokenizer_group
:
BaseTokenizerGroup
):
self
.
tokenizer_group
=
tokenizer_group
self
.
tokenizer_group
=
tokenizer_group
def
get_tokenizer_for_seq
(
self
,
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
"""Returns the HF tokenizer to use for a given sequence."""
"""Returns the HF tokenizer to use for a given sequence."""
return
self
.
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
return
self
.
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
...
@@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):
...
@@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):
def
_convert_tokens_to_string_with_added_encoders
(
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
AnyTokenizer
,
output_tokens
:
List
[
str
],
output_tokens
:
List
[
str
],
skip_special_tokens
:
bool
,
skip_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
...
@@ -213,7 +211,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
...
@@ -213,7 +211,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def
convert_prompt_ids_to_tokens
(
def
convert_prompt_ids_to_tokens
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
AnyTokenizer
,
prompt_ids
:
List
[
int
],
prompt_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
False
,
skip_special_tokens
:
bool
=
False
,
)
->
Tuple
[
List
[
str
],
int
,
int
]:
)
->
Tuple
[
List
[
str
],
int
,
int
]:
...
@@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
...
@@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
# under Apache 2.0 license
def
detokenize_incrementally
(
def
detokenize_incrementally
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
AnyTokenizer
,
all_input_ids
:
List
[
int
],
all_input_ids
:
List
[
int
],
prev_tokens
:
Optional
[
List
[
str
]],
prev_tokens
:
Optional
[
List
[
str
]],
prefix_offset
:
int
,
prefix_offset
:
int
,
...
...
vllm/transformers_utils/tokenizer.py
View file @
9ba85bc1
...
@@ -12,10 +12,10 @@ from vllm.lora.request import LoRARequest
...
@@ -12,10 +12,10 @@ from vllm.lora.request import LoRARequest
from
vllm.transformers_utils.tokenizers
import
BaichuanTokenizer
from
vllm.transformers_utils.tokenizers
import
BaichuanTokenizer
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
from
.tokenizer_group
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
"""Get tokenizer with cached properties.
"""Get tokenizer with cached properties.
...
@@ -141,7 +141,7 @@ def get_tokenizer(
...
@@ -141,7 +141,7 @@ def get_tokenizer(
def
get_lora_tokenizer
(
lora_request
:
LoRARequest
,
*
args
,
def
get_lora_tokenizer
(
lora_request
:
LoRARequest
,
*
args
,
**
kwargs
)
->
Optional
[
PreTrained
Tokenizer
]:
**
kwargs
)
->
Optional
[
Any
Tokenizer
]:
if
lora_request
is
None
:
if
lora_request
is
None
:
return
None
return
None
try
:
try
:
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
9ba85bc1
...
@@ -8,8 +8,7 @@ from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
...
@@ -8,8 +8,7 @@ from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from
.tokenizer_group
import
TokenizerGroup
from
.tokenizer_group
import
TokenizerGroup
if
ray
:
if
ray
:
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
from
.ray_tokenizer_group
import
RayTokenizerGroupPool
RayTokenizerGroupPool
)
else
:
else
:
RayTokenizerGroupPool
=
None
# type: ignore
RayTokenizerGroupPool
=
None
# type: ignore
...
...
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
9ba85bc1
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
class
BaseTokenizerGroup
(
ABC
):
class
BaseTokenizerGroup
(
ABC
):
...
@@ -24,8 +21,9 @@ class BaseTokenizerGroup(ABC):
...
@@ -24,8 +21,9 @@ class BaseTokenizerGroup(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
get_max_input_len
(
self
,
def
get_max_input_len
(
lora_request
:
Optional
[
LoRARequest
]
=
None
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
Optional
[
int
]:
)
->
Optional
[
int
]:
"""Get the maximum input length for the LoRA request."""
"""Get the maximum input length for the LoRA request."""
pass
pass
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
9ba85bc1
...
@@ -13,8 +13,9 @@ from vllm.config import TokenizerPoolConfig
...
@@ -13,8 +13,9 @@ from vllm.config import TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.executor.ray_utils
import
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
.base_tokenizer_group
import
AnyTokenizer
,
BaseTokenizerGroup
from
.base_tokenizer_group
import
BaseTokenizerGroup
from
.tokenizer_group
import
TokenizerGroup
from
.tokenizer_group
import
TokenizerGroup
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
9ba85bc1
...
@@ -2,12 +2,13 @@ from typing import List, Optional
...
@@ -2,12 +2,13 @@ from typing import List, Optional
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
get_lora_tokenizer
,
get_lora_tokenizer_async
,
get_lora_tokenizer_async
,
get_tokenizer
)
get_tokenizer
)
from
vllm.utils
import
LRUCache
from
vllm.utils
import
LRUCache
from
.base_tokenizer_group
import
AnyTokenizer
,
BaseTokenizerGroup
from
.base_tokenizer_group
import
BaseTokenizerGroup
class
TokenizerGroup
(
BaseTokenizerGroup
):
class
TokenizerGroup
(
BaseTokenizerGroup
):
...
...
vllm/utils.py
View file @
9ba85bc1
...
@@ -1101,9 +1101,9 @@ def cuda_device_count_stateless() -> int:
...
@@ -1101,9 +1101,9 @@ def cuda_device_count_stateless() -> int:
#From: https://stackoverflow.com/a/4104188/2749989
#From: https://stackoverflow.com/a/4104188/2749989
def
run_once
(
f
)
:
def
run_once
(
f
:
Callable
[
P
,
None
])
->
Callable
[
P
,
None
]
:
def
wrapper
(
*
args
,
**
kwargs
)
->
Any
:
def
wrapper
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
None
:
if
not
wrapper
.
has_run
:
# type: ignore[attr-defined]
if
not
wrapper
.
has_run
:
# type: ignore[attr-defined]
wrapper
.
has_run
=
True
# type: ignore[attr-defined]
wrapper
.
has_run
=
True
# type: ignore[attr-defined]
return
f
(
*
args
,
**
kwargs
)
return
f
(
*
args
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment