Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0e9164b4
Unverified
Commit
0e9164b4
authored
Jun 15, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 15, 2024
Browse files
[mypy] Enable type checking for test directory (#5017)
parent
1b8a0d71
Changes
92
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
49 additions
and
41 deletions
+49
-41
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+2
-2
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+2
-2
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+2
-2
vllm/sequence.py
vllm/sequence.py
+1
-1
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+5
-5
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+3
-3
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+4
-4
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+2
-2
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+1
-1
vllm/utils.py
vllm/utils.py
+23
-15
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-2
No files found.
vllm/model_executor/models/arctic.py
View file @
0e9164b4
...
@@ -453,8 +453,8 @@ class ArcticForCausalLM(nn.Module):
...
@@ -453,8 +453,8 @@ class ArcticForCausalLM(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
]
mlp_params_mapping
=
[]
mlp_params_mapping
:
List
[
Tuple
[
str
,
str
,
int
]]
=
[]
expert_params_mapping
=
[]
expert_params_mapping
:
List
[
Tuple
[
str
,
str
,
int
]]
=
[]
num_layers
=
self
.
config
.
num_hidden_layers
num_layers
=
self
.
config
.
num_hidden_layers
for
layer
in
range
(
num_layers
):
for
layer
in
range
(
num_layers
):
...
...
vllm/model_executor/models/commandr.py
View file @
0e9164b4
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
"""PyTorch Cohere model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
...
@@ -352,7 +352,7 @@ class CohereForCausalLM(nn.Module):
...
@@ -352,7 +352,7 @@ class CohereForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
for
param_name
,
shard_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
shard_name
,
shard_id
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
if
shard_name
not
in
name
:
...
...
vllm/model_executor/models/gemma.py
View file @
0e9164b4
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# limitations under the License.
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
"""Inference-only Gemma model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -363,7 +363,7 @@ class GemmaForCausalLM(nn.Module):
...
@@ -363,7 +363,7 @@ class GemmaForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
if
shard_name
not
in
name
:
...
...
vllm/sequence.py
View file @
0e9164b4
...
@@ -123,7 +123,7 @@ class SequenceData:
...
@@ -123,7 +123,7 @@ class SequenceData:
output_token_ids
=
[]
output_token_ids
=
[]
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
prompt_token_ids
)
self
.
_prompt_token_ids_tuple
=
tuple
(
prompt_token_ids
)
self
.
output_token_ids
=
output_token_ids
self
.
output_token_ids
=
output_token_ids
self
.
cumulative_logprob
=
0.0
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
# The number of tokens that are computed (that run against the model).
...
...
vllm/spec_decode/multi_step_worker.py
View file @
0e9164b4
import
copy
import
copy
import
weakref
import
weakref
from
typing
import
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
torch
import
torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
...
@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
sample_len
)
sample_len
)
# Run model sample_len times.
# Run model sample_len times.
model_outputs
=
[]
model_outputs
:
List
[
SamplerOutput
]
=
[]
for
_
in
range
(
sample_len
):
for
_
in
range
(
sample_len
):
model_output
=
super
().
execute_model
(
model_output
=
super
().
execute_model
(
execute_model_req
=
copied_execute_model_req
)
execute_model_req
=
copied_execute_model_req
)
...
@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list
=
[]
new_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
old_seq_group_metadata
in
seq_group_metadata_list
:
for
old_seq_group_metadata
in
seq_group_metadata_list
:
# We must shallow-copy seq_group_metadata as is_prompt could change.
# We must shallow-copy seq_group_metadata as is_prompt could change.
...
@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
new_seq_group_metadata_list
.
append
(
seq_group_metadata
)
new_seq_group_metadata_list
.
append
(
seq_group_metadata
)
# We must shallow-copy seq_data as we will append token ids
# We must shallow-copy seq_data as we will append token ids
new_seq_data
=
{}
new_seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_id
,
old_seq_data
in
seq_group_metadata
.
seq_data
.
items
():
for
seq_id
,
old_seq_data
in
seq_group_metadata
.
seq_data
.
items
():
new_seq_data
[
seq_id
]
=
copy
.
copy
(
old_seq_data
)
new_seq_data
[
seq_id
]
=
copy
.
copy
(
old_seq_data
)
new_seq_data
[
new_seq_data
[
...
...
vllm/spec_decode/ngram_worker.py
View file @
0e9164b4
...
@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
...
@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
sample_len
:
int
,
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
)
->
Tuple
[
Optional
[
List
[
Optional
[
SamplerOutput
]]
]
,
bool
]:
"""NGram match algo to pick proposal candidate. Returns the list of
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
sampler output, one per SequenceGroupMetadata.
...
@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
...
@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self
.
_raise_if_unsupported
(
execute_model_req
)
self
.
_raise_if_unsupported
(
execute_model_req
)
has_spec_out
=
False
has_spec_out
=
False
token_id_list
=
[]
token_id_list
:
List
[
Optional
[
torch
.
Tensor
]]
=
[]
token_prob_list
=
[]
token_prob_list
:
List
[
Optional
[
torch
.
Tensor
]]
=
[]
for
idx
,
seq_group_metadata
in
enumerate
(
for
idx
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
execute_model_req
.
seq_group_metadata_list
):
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
0e9164b4
...
@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig
...
@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
SequenceGroupMetadata
)
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
...
@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
topk_indices_by_step
=
topk_indices_by_step
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
# Construct the output on a per-step, per-sequence basis.
sampler_output_list
=
[]
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
for
step_index
in
range
(
num_steps
):
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
if
all
(
token_id
==
-
1
for
token_id
in
accepted_token_ids_by_step
[
step_index
]):
for
token_id
in
accepted_token_ids_by_step
[
step_index
]):
break
break
step_output_token_ids
=
[]
step_output_token_ids
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
sequence_index
in
range
(
batch_size
):
for
sequence_index
in
range
(
batch_size
):
# Each sequence may have a different num_logprobs; retrieve it.
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
...
...
vllm/spec_decode/util.py
View file @
0e9164b4
...
@@ -26,10 +26,10 @@ def get_all_num_logprobs(
...
@@ -26,10 +26,10 @@ def get_all_num_logprobs(
sequence.
sequence.
"""
"""
all_num_logprobs
=
[]
all_num_logprobs
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
num_logprobs
=
seq_group_metadata
.
sampling_params
.
logprobs
num_logprobs
=
seq_group_metadata
.
sampling_params
.
logprobs
if
seq_group_metadata
.
sampling_params
.
logprobs
is
None
:
if
num_
logprobs
is
None
:
num_logprobs
=
0
num_logprobs
=
0
all_num_logprobs
.
append
(
num_logprobs
)
all_num_logprobs
.
append
(
num_logprobs
)
...
...
vllm/transformers_utils/detokenizer.py
View file @
0e9164b4
...
@@ -44,7 +44,7 @@ class Detokenizer:
...
@@ -44,7 +44,7 @@ class Detokenizer:
read_offset
=
0
read_offset
=
0
next_iter_prefix_offset
=
0
next_iter_prefix_offset
=
0
next_iter_read_offset
=
0
next_iter_read_offset
=
0
next_iter_tokens
=
[]
next_iter_tokens
:
List
[
str
]
=
[]
prev_tokens
=
None
prev_tokens
=
None
for
token_position
,
prompt_logprobs_for_token
in
enumerate
(
for
token_position
,
prompt_logprobs_for_token
in
enumerate
(
...
...
vllm/utils.py
View file @
0e9164b4
...
@@ -20,12 +20,13 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
...
@@ -20,12 +20,13 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
import
numpy
as
np
import
numpy
as
np
import
psutil
import
psutil
import
torch
import
torch
import
torch.types
from
typing_extensions
import
ParamSpec
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.logger
import
enable_trace_function_call
,
init_logger
T
=
TypeVar
(
"T"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
STR_DTYPE_TO_TORCH_DTYPE
=
{
STR_DTYPE_TO_TORCH_DTYPE
=
{
...
@@ -37,6 +38,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -37,6 +38,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2"
:
torch
.
uint8
,
"fp8_e5m2"
:
torch
.
uint8
,
}
}
P
=
ParamSpec
(
'P'
)
K
=
TypeVar
(
"K"
)
T
=
TypeVar
(
"T"
)
class
Device
(
enum
.
Enum
):
class
Device
(
enum
.
Enum
):
GPU
=
enum
.
auto
()
GPU
=
enum
.
auto
()
...
@@ -176,7 +181,7 @@ def random_uuid() -> str:
...
@@ -176,7 +181,7 @@ def random_uuid() -> str:
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_vllm_instance_id
():
def
get_vllm_instance_id
()
->
str
:
"""
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
Otherwise, return a random UUID.
...
@@ -192,7 +197,7 @@ def in_wsl() -> bool:
...
@@ -192,7 +197,7 @@ def in_wsl() -> bool:
return
"microsoft"
in
" "
.
join
(
uname
()).
lower
()
return
"microsoft"
in
" "
.
join
(
uname
()).
lower
()
def
make_async
(
func
:
Callable
[
...
,
T
])
->
Callable
[
...
,
Awaitable
[
T
]]:
def
make_async
(
func
:
Callable
[
P
,
T
])
->
Callable
[
P
,
Awaitable
[
T
]]:
"""Take a blocking function, and run it on in an executor thread.
"""Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
This function prevents the blocking function from blocking the
...
@@ -200,7 +205,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
...
@@ -200,7 +205,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
The code in this function needs to be thread safe.
The code in this function needs to be thread safe.
"""
"""
def
_async_wrapper
(
*
args
,
**
kwargs
)
->
asyncio
.
Future
:
def
_async_wrapper
(
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
)
->
asyncio
.
Future
:
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
p_func
=
partial
(
func
,
*
args
,
**
kwargs
)
p_func
=
partial
(
func
,
*
args
,
**
kwargs
)
return
loop
.
run_in_executor
(
executor
=
None
,
func
=
p_func
)
return
loop
.
run_in_executor
(
executor
=
None
,
func
=
p_func
)
...
@@ -325,7 +330,7 @@ def update_environment_variables(envs: Dict[str, str]):
...
@@ -325,7 +330,7 @@ def update_environment_variables(envs: Dict[str, str]):
os
.
environ
[
k
]
=
v
os
.
environ
[
k
]
=
v
def
chunk_list
(
lst
,
chunk_size
)
:
def
chunk_list
(
lst
:
List
[
T
]
,
chunk_size
:
int
)
->
List
[
List
[
T
]]
:
"""Yield successive chunk_size chunks from lst."""
"""Yield successive chunk_size chunks from lst."""
return
[
lst
[
i
:
i
+
chunk_size
]
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
)]
return
[
lst
[
i
:
i
+
chunk_size
]
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
)]
...
@@ -336,7 +341,7 @@ def cdiv(a: int, b: int) -> int:
...
@@ -336,7 +341,7 @@ def cdiv(a: int, b: int) -> int:
def
_generate_random_fp8
(
def
_generate_random_fp8
(
tensor
:
torch
.
t
ensor
,
tensor
:
torch
.
T
ensor
,
low
:
float
,
low
:
float
,
high
:
float
,
high
:
float
,
)
->
None
:
)
->
None
:
...
@@ -398,7 +403,10 @@ def create_kv_caches_with_random_flash(
...
@@ -398,7 +403,10 @@ def create_kv_caches_with_random_flash(
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
key_caches
,
value_caches
=
[],
[]
key_caches
:
List
[
torch
.
Tensor
]
=
[]
value_caches
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
size
=
key_value_cache_shape
,
key_value_cache
=
torch
.
empty
(
size
=
key_value_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
...
@@ -429,7 +437,7 @@ def create_kv_caches_with_random(
...
@@ -429,7 +437,7 @@ def create_kv_caches_with_random(
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_caches
=
[]
key_caches
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
...
@@ -444,7 +452,7 @@ def create_kv_caches_with_random(
...
@@ -444,7 +452,7 @@ def create_kv_caches_with_random(
key_caches
.
append
(
key_cache
)
key_caches
.
append
(
key_cache
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_caches
=
[]
value_caches
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
...
@@ -484,7 +492,7 @@ def is_pin_memory_available() -> bool:
...
@@ -484,7 +492,7 @@ def is_pin_memory_available() -> bool:
class
CudaMemoryProfiler
:
class
CudaMemoryProfiler
:
def
__init__
(
self
,
device
=
None
):
def
__init__
(
self
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
):
self
.
device
=
device
self
.
device
=
device
def
current_memory_usage
(
self
)
->
float
:
def
current_memory_usage
(
self
)
->
float
:
...
@@ -560,13 +568,13 @@ def get_dtype_size(dtype: torch.dtype) -> int:
...
@@ -560,13 +568,13 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
def
merge_dicts
(
dict1
:
Dict
[
Any
,
List
[
Any
]],
def
merge_dicts
(
dict1
:
Dict
[
K
,
List
[
T
]],
dict2
:
Dict
[
Any
,
List
[
Any
]])
->
Dict
[
Any
,
List
[
Any
]]:
dict2
:
Dict
[
K
,
List
[
T
]])
->
Dict
[
K
,
List
[
T
]]:
"""Merge 2 dicts that have key -> List of items.
"""Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 is prioritized.
When a key conflicts, the values in dict1 is prioritized.
"""
"""
merged_dict
=
defaultdict
(
list
)
merged_dict
:
Dict
[
K
,
List
[
T
]]
=
defaultdict
(
list
)
for
key
,
value
in
dict1
.
items
():
for
key
,
value
in
dict1
.
items
():
merged_dict
[
key
].
extend
(
value
)
merged_dict
[
key
].
extend
(
value
)
...
@@ -577,7 +585,7 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
...
@@ -577,7 +585,7 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
return
dict
(
merged_dict
)
return
dict
(
merged_dict
)
def
init_cached_hf_modules
():
def
init_cached_hf_modules
()
->
None
:
"""
"""
Lazy initialization of the Hugging Face modules.
Lazy initialization of the Hugging Face modules.
"""
"""
...
@@ -613,7 +621,7 @@ def find_library(lib_name: str) -> str:
...
@@ -613,7 +621,7 @@ def find_library(lib_name: str) -> str:
return
locs
[
0
]
return
locs
[
0
]
def
find_nccl_library
():
def
find_nccl_library
()
->
str
:
"""
"""
We either use the library file specified by the `VLLM_NCCL_SO_PATH`
We either use the library file specified by the `VLLM_NCCL_SO_PATH`
environment variable, or we find the library file brought by PyTorch.
environment variable, or we find the library file brought by PyTorch.
...
...
vllm/worker/model_runner.py
View file @
0e9164b4
...
@@ -779,8 +779,8 @@ class ModelRunner:
...
@@ -779,8 +779,8 @@ class ModelRunner:
# that will have unique loras, an therefore the max amount of memory
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
# passed in, which contains a lora from the lora warmup path.
dummy_lora_requests
=
[]
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
if
self
.
lora_config
:
assert
self
.
lora_manager
is
not
None
assert
self
.
lora_manager
is
not
None
with
self
.
lora_manager
.
dummy_lora_cache
():
with
self
.
lora_manager
.
dummy_lora_cache
():
...
...
vllm/worker/worker_base.py
View file @
0e9164b4
...
@@ -99,8 +99,8 @@ class WorkerWrapperBase:
...
@@ -99,8 +99,8 @@ class WorkerWrapperBase:
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
worker_module_name
=
None
,
worker_module_name
:
str
,
worker_class_name
=
None
,
worker_class_name
:
str
,
trust_remote_code
:
bool
=
False
)
->
None
:
trust_remote_code
:
bool
=
False
)
->
None
:
self
.
worker_module_name
=
worker_module_name
self
.
worker_module_name
=
worker_module_name
self
.
worker_class_name
=
worker_class_name
self
.
worker_class_name
=
worker_class_name
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment