Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
df29793d
Unverified
Commit
df29793d
authored
Apr 29, 2024
by
SangBin Cho
Committed by
GitHub
Apr 28, 2024
Browse files
[mypy][5/N] Support all typing on model executor (#4427)
parent
03dd7d52
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
61 additions
and
34 deletions
+61
-34
.github/workflows/mypy.yaml
.github/workflows/mypy.yaml
+1
-1
format.sh
format.sh
+1
-1
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
...l_executor/guided_decoding/lm_format_enforcer_decoding.py
+1
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+11
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-2
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+11
-3
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+2
-3
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+2
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+27
-20
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+3
-1
No files found.
.github/workflows/mypy.yaml
View file @
df29793d
...
@@ -43,8 +43,8 @@ jobs:
...
@@ -43,8 +43,8 @@ jobs:
mypy vllm/worker --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
# TODO(sang): Fix nested dir
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
format.sh
View file @
df29793d
...
@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
...
@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine
--config-file
pyproject.toml
mypy vllm/engine
--config-file
pyproject.toml
mypy vllm/worker
--config-file
pyproject.toml
mypy vllm/worker
--config-file
pyproject.toml
mypy vllm/spec_decode
--config-file
pyproject.toml
mypy vllm/spec_decode
--config-file
pyproject.toml
mypy vllm/model_executor
/
*
.py
--config-file
pyproject.toml
mypy vllm/model_executor
--config-file
pyproject.toml
mypy vllm/lora
--config-file
pyproject.toml
mypy vllm/lora
--config-file
pyproject.toml
...
...
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
View file @
df29793d
...
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
...
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return
schema
return
schema
if
isinstance
(
schema
,
BaseModel
):
if
isinstance
(
schema
,
BaseModel
):
return
schema
.
model_json_schema
()
return
schema
.
model_json_schema
()
raise
AssertionError
(
f
"Unsupported schema type
{
schema
}
"
)
@
lru_cache
@
lru_cache
...
...
vllm/model_executor/layers/linear.py
View file @
df29793d
...
@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
...
@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
params_dtype
=
params_dtype
if
quant_config
is
None
:
if
quant_config
is
None
:
self
.
quant_method
=
UnquantizedLinearMethod
()
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
...
@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
...
@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
self
.
output_size
,
self
.
params_dtype
)
...
@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
return
output
,
output_bias
...
@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
self
.
output_size_per_partition
=
divide
(
output_size
,
tp_size
)
self
.
output_size_per_partition
=
divide
(
output_size
,
tp_size
)
if
output_sizes
is
None
:
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
output_sizes
=
[
output_size
]
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
input_size
,
[
x
//
tp_size
for
x
in
output_sizes
],
[
x
//
tp_size
for
x
in
output_sizes
],
...
@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
# Matrix multiply.
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
if
self
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
...
@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
...
@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size_per_partition
,
self
.
input_size_per_partition
,
[
self
.
output_size
],
[
self
.
output_size
],
...
@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
...
@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
# Matrix multiply.
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
df29793d
from
typing
import
Type
from
typing
import
Dict
,
Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
QUANTIZATION_METHODS
=
{
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
df29793d
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -76,8 +76,16 @@ class QuantizationConfig(ABC):
...
@@ -76,8 +76,16 @@ class QuantizationConfig(ABC):
"quantization config."
)
"quantization config."
)
@
abstractmethod
@
abstractmethod
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
QuantizeMethodBase
:
def
get_quant_method
(
"""Get the quantize method to use for the quantized layer."""
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
df29793d
...
@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig):
...
@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig):
return
cls
(
weight_bits
)
return
cls
(
weight_bits
)
def
get_quant_method
(
def
get_quant_method
(
self
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"SqueezeLLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
SqueezeLLMLinearMethod
(
self
)
return
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
return
[]
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
df29793d
...
@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
torch
.
full_like
(
positions
,
k
)).
long
()
torch
.
full_like
(
positions
,
k
)).
long
()
idx
=
(
torch
.
add
(
positions
,
long_prompt_offset
)
idx
=
(
torch
.
add
(
positions
,
long_prompt_offset
)
if
long_prompt_offset
is
not
None
else
positions
)
if
long_prompt_offset
is
not
None
else
positions
)
self
.
long_short_cos_sin_cache
=
self
.
long_short_cos_sin_cache
.
to
(
self
.
long_short_cos_sin_cache
:
torch
.
Tensor
=
(
idx
.
device
)
self
.
long_short_cos_sin_cache
.
to
(
idx
.
device
)
)
idx
=
torch
.
add
(
idx
,
offsets
)
if
offsets
is
not
None
else
idx
idx
=
torch
.
add
(
idx
,
offsets
)
if
offsets
is
not
None
else
idx
cos_sin
=
torch
.
index_select
(
self
.
long_short_cos_sin_cache
,
0
,
idx
)
cos_sin
=
torch
.
index_select
(
self
.
long_short_cos_sin_cache
,
0
,
idx
)
...
...
vllm/model_executor/layers/sampler.py
View file @
df29793d
...
@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType
...
@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
"""Samples the next tokens from the model's outputs.
"""Samples the next tokens from the model's outputs.
...
@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
...
@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
have not been generated yet
have not been generated yet
"""
"""
# list of indices in logits that will be set to -inf
# list of indices in logits that will be set to -inf
logits_to_penalize
=
[]
logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
logits_applied
=
0
logits_applied
=
0
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
...
@@ -269,7 +272,7 @@ def _apply_min_p(
...
@@ -269,7 +272,7 @@ def _apply_min_p(
def
_greedy_sample
(
def
_greedy_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
samples
:
torch
.
Tensor
,
samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
:
)
->
SampleResultType
:
"""Run greedy sampling on a given samples.
"""Run greedy sampling on a given samples.
Args:
Args:
...
@@ -284,7 +287,7 @@ def _greedy_sample(
...
@@ -284,7 +287,7 @@ def _greedy_sample(
"""
"""
samples
=
samples
.
tolist
()
samples
=
samples
.
tolist
()
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
results
.
append
(([],
[]))
...
@@ -304,7 +307,7 @@ def _greedy_sample(
...
@@ -304,7 +307,7 @@ def _greedy_sample(
def
_random_sample
(
def
_random_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
random_samples
:
torch
.
Tensor
,
random_samples
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
:
)
->
SampleResultType
:
"""Run random sampling on a given samples.
"""Run random sampling on a given samples.
Args:
Args:
...
@@ -320,7 +323,7 @@ def _random_sample(
...
@@ -320,7 +323,7 @@ def _random_sample(
# Find the maximum best_of value of the prompt phase requests.
# Find the maximum best_of value of the prompt phase requests.
random_samples
=
random_samples
.
cpu
()
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
results
.
append
(([],
[]))
...
@@ -348,7 +351,7 @@ def _random_sample(
...
@@ -348,7 +351,7 @@ def _random_sample(
def
_beam_search_sample
(
def
_beam_search_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
selected_seq_groups
:
List
[
SequenceGroupToSample
],
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
:
)
->
SampleResultType
:
"""Run beam sampling on a given samples.
"""Run beam sampling on a given samples.
Args:
Args:
...
@@ -370,7 +373,7 @@ def _beam_search_sample(
...
@@ -370,7 +373,7 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
# other sampling methods.
sample_idx
=
0
sample_idx
=
0
results
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
results
.
append
(([],
[]))
...
@@ -391,16 +394,16 @@ def _beam_search_sample(
...
@@ -391,16 +394,16 @@ def _beam_search_sample(
next_token_ids
=
next_token_ids
.
tolist
()
next_token_ids
=
next_token_ids
.
tolist
()
else
:
else
:
# Generation phase.
# Generation phase.
cumulative_logprobs
=
[
cumulative_logprobs
:
List
[
int
]
=
[
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
for
seq_id
in
seq_ids
]
]
cumulative_logprobs
=
torch
.
tensor
(
cumulative_logprobs
_tensor
=
torch
.
tensor
(
cumulative_logprobs
,
cumulative_logprobs
,
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
seq_group_logprobs
.
device
)
device
=
seq_group_logprobs
.
device
)
seq_group_logprobs
=
(
seq_group_logprobs
+
seq_group_logprobs
=
(
seq_group_logprobs
+
cumulative_logprobs
.
unsqueeze
(
dim
=
1
))
cumulative_logprobs
_tensor
.
unsqueeze
(
dim
=
1
))
_
,
topk_ids
=
torch
.
topk
(
seq_group_logprobs
.
flatten
(),
_
,
topk_ids
=
torch
.
topk
(
seq_group_logprobs
.
flatten
(),
2
*
beam_width
)
2
*
beam_width
)
topk_ids
=
topk_ids
.
tolist
()
topk_ids
=
topk_ids
.
tolist
()
...
@@ -452,8 +455,10 @@ def _sample_with_torch(
...
@@ -452,8 +455,10 @@ def _sample_with_torch(
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
include_gpu_probs_tensor
:
bool
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
...
@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
...
@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
sampling_tensors
:
SamplingTensors
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
)
->
SampleResultType
:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
...
@@ -632,7 +639,7 @@ def _sample(
...
@@ -632,7 +639,7 @@ def _sample(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
"""
"""
Args:
Args:
probs: (num_query_tokens_in_batch, num_vocab)
probs: (num_query_tokens_in_batch, num_vocab)
...
@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
...
@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def
_get_logprobs
(
def
_get_logprobs
(
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
,
sample_results
:
SampleResultType
,
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
)
->
Tuple
[
List
[
Optional
[
PromptLogprobs
]],
List
[
SampleLogprobs
]]:
"""Return sample lobprobs and prompt logprobs.
"""Return sample lobprobs and prompt logprobs.
...
@@ -751,8 +758,8 @@ def _get_logprobs(
...
@@ -751,8 +758,8 @@ def _get_logprobs(
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
if
len
(
query_indices
)
==
0
:
if
len
(
query_indices
)
==
0
:
empty_sampled_logprob
=
[]
empty_sampled_logprob
:
SampleLogprobs
=
[]
empty_prompt_logprob
=
None
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
...
@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...
@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def
_build_sampler_output
(
def
_build_sampler_output
(
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
,
sample_results
:
SampleResultType
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
sample_logprobs
:
List
[
SampleLogprobs
],
...
@@ -1009,7 +1016,7 @@ def _build_sampler_output(
...
@@ -1009,7 +1016,7 @@ def _build_sampler_output(
)
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
str
]:
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
"""Get a list of next prompt tokens to compute logprob from a
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
given sequence group.
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
df29793d
...
@@ -64,7 +64,7 @@ class TensorizerConfig:
...
@@ -64,7 +64,7 @@ class TensorizerConfig:
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_endpoint"
:
self
.
s3_endpoint
,
"s3_endpoint"
:
self
.
s3_endpoint
,
}
}
return
TensorizerArgs
(
**
tensorizer_args
)
return
TensorizerArgs
(
**
tensorizer_args
)
# type: ignore
def
verify_with_parallel_config
(
def
verify_with_parallel_config
(
self
,
self
,
...
@@ -270,8 +270,10 @@ class TensorizerAgent:
...
@@ -270,8 +270,10 @@ class TensorizerAgent:
self
.
model
=
self
.
_init_model
()
self
.
model
=
self
.
_init_model
()
def
_init_model
(
self
):
def
_init_model
(
self
):
assert
self
.
tensorizer_config
.
hf_config
is
not
None
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
assert
self
.
tensorizer_config
.
model_class
is
not
None
with
no_init_or_tensor
():
with
no_init_or_tensor
():
return
self
.
tensorizer_config
.
model_class
(
return
self
.
tensorizer_config
.
model_class
(
config
=
model_args
,
config
=
model_args
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment