norm / vllm / Commits / ead94d93

Commit ead94d93 authored Jan 16, 2024 by zhuwenwen

    Merge remote-tracking branch 'mirror/main'

Parents: fcffb7c8, f780504d

Changes: 40
Showing 20 changed files with 1220 additions and 145 deletions (+1220 / -145)
  tests/kernels/test_cache.py                       +2    -2
  tests/samplers/test_logprobs.py                   +1    -0
  tests/samplers/test_rejection_sampler.py          +392  -0
  tests/samplers/test_sampler.py                    +74   -0
  vllm/core/policy.py                               +10   -8
  vllm/core/scheduler.py                            +42   -27
  vllm/engine/async_llm_engine.py                   +51   -1
  vllm/engine/llm_engine.py                         +109  -8
  vllm/engine/ray_utils.py                          +1    -1
  vllm/entrypoints/api_server.py                    +6    -0
  vllm/entrypoints/openai/api_server.py             +25   -1
  vllm/model_executor/layers/attention.py           +9    -14
  vllm/model_executor/layers/linear.py              +4    -1
  vllm/model_executor/layers/rejection_sampler.py   +392  -0
  vllm/model_executor/layers/sampler.py             +15   -15
  vllm/model_executor/models/__init__.py            +1    -1
  vllm/model_executor/models/phi.py                 +59   -61
  vllm/utils.py                                     +3    -1
  vllm/worker/model_runner.py                       +13   -3
  vllm/worker/worker.py                             +11   -1
tests/kernels/test_cache.py  (view file @ ead94d93)

@@ -6,12 +6,12 @@ import torch
 from vllm._C import cache_ops

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
...
tests/samplers/test_logprobs.py  (view file @ ead94d93)

@@ -30,6 +30,7 @@ def test_get_prompt_logprobs(
         temperature=0.0)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
+    del vllm_model

     # Test whether logprobs are included in the results.
     for result in vllm_results:
...
tests/samplers/test_rejection_sampler.py  0 → 100644  (new file, view file @ ead94d93)

"""Tests for rejection sampling."""
import pytest
from typing import List, Tuple

import torch
import torch.nn.functional as F

from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler


def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted tokens
    up to last accepted indices.

    Tokens after last_accepted_indices+1 may also be accepted, although they
    will not be causally accepted.
    """
    batch_size = last_accepted_indices.shape[0]

    accepted = (torch.arange(k).expand(batch_size, k) <=
                last_accepted_indices.unsqueeze(-1).broadcast_to(
                    batch_size, k)).to(device="cuda")

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality.
    sprinkle_candidates = (
        torch.arange(k).expand(batch_size, k) >
        last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
    sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int):
    """Verify the output has correct format given predetermined accepted matrix.
    """
    set_random_seed(seed)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1, high=k, size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0, high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64, device="cuda")
    draft_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64, device="cuda")
    bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64, device="cuda")

    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, bonus_token_ids), dim=-1)
        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size, k, vocab_size,
                             dtype=torch.float32, device="cuda")
    target_probs = torch.rand(batch_size, k, vocab_size,
                              dtype=torch.float32, device="cuda")
    bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64, device="cuda")
    draft_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64, device="cuda")

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str):
    k = 3
    batch_size = 5
    vocab_size = 30_000

    rejection_sampler = RejectionSampler(strict_mode=True)
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size, k, vocab_size,
                             dtype=torch.float32, device="cuda")
    target_probs = torch.rand(batch_size, k, vocab_size,
                              dtype=torch.float32, device="cuda")
    bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64, device="cuda")
    draft_token_ids = torch.randint(low=0, high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64, device="cuda")

    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    set_random_seed(seed)

    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference = []
    distance_wrt_target = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: List[float]) -> float:
    return elements[0] / elements[-1]


class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(rank=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        draft_probs, target_probs = [
            F.softmax(
                torch.rand(self.vocab_size, dtype=torch.float32),
                dim=-1,
            ) for _ in range(2)
        ]

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs, rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * k times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist
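
The correctness test above relies on get_ratio_first_to_last to quantify how much each distance shrinks as the sample count grows. A quick worked sketch with invented numbers (not taken from the commit) shows how the final assertion behaves:

# Illustrative only: hypothetical distances as num_samples grows 10 -> 100_000.
distance_wrt_target = [0.40, 0.12, 0.04, 0.012, 0.004]     # shrinks ~100x
distance_wrt_reference = [0.30, 0.28, 0.27, 0.27, 0.27]    # barely moves (~1.1x)

def get_ratio_first_to_last(elements):
    return elements[0] / elements[-1]

improvement_target = get_ratio_first_to_last(distance_wrt_target)        # 100.0
improvement_reference = get_ratio_first_to_last(distance_wrt_reference)  # ~1.11

# The test requires the target-distance improvement to beat the
# reference-distance improvement by at least 20x.
assert improvement_target > improvement_reference * 20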
tests/samplers/test_sampler.py  (view file @ ead94d93)

...
@@ -4,6 +4,7 @@ from unittest.mock import patch
 import pytest
 import torch
+from transformers import GenerationConfig, GenerationMixin

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
...
@@ -74,6 +75,8 @@ def test_sampler_all_greedy(seed: int):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()

+    del model_runner
+

 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_random(seed: int):
...
@@ -110,6 +113,8 @@ def test_sampler_all_random(seed: int):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i

+    del model_runner
+

 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_beam(seed: int):
...
@@ -143,6 +148,7 @@ def test_sampler_all_beam(seed: int):
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
     # when handling an all-beam search case.
+    del model_runner

 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
...
@@ -197,6 +203,8 @@ def test_sampler_mixed(seed: int):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens

+    del model_runner
+

 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_logits_processors(seed: int):
...
@@ -233,3 +241,69 @@ def test_sampler_logits_processors(seed: int):
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_top_k_top_p(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    top_k = random.randint(100, 500)
+    top_p = random.random() * 0.1
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024),
+                              device="cuda",
+                              dtype=torch.float16)
+    fake_logits = torch.normal(0,
+                               5,
+                               size=(batch_size, vocab_size),
+                               device=input_tensor.device,
+                               dtype=input_tensor.dtype)
+    sampler = MockLogitsSampler(32000, fake_logits)
+    model_runner = ModelRunner(None, None, None)
+
+    generation_model = GenerationMixin()
+    generation_config = GenerationConfig(top_k=top_k,
+                                         top_p=top_p,
+                                         do_sample=True)
+    warpers = generation_model._get_logits_warper(generation_config)
+    assert len(warpers) == 2  # top_p and top_k
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1,
+                    top_k=top_k,
+                    top_p=top_p,
+                ),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+    sample_probs = None
+
+    def mock_sample(probs, logprobs, sampling_metadata):
+        nonlocal sample_probs
+        sample_probs = probs
+        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
+
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+        sampler(embedding=None,
+                hidden_states=input_tensor,
+                sampling_metadata=sampling_metadata)
+    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
+    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
+    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
+    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
+
+    del model_runner
vllm/core/policy.py  (view file @ ead94d93)

-from typing import List
+from collections import deque
+from typing import Deque

 from vllm.sequence import SequenceGroup
...
@@ -15,13 +16,14 @@ class Policy:

     def sort_by_priority(
         self,
         now: float,
-        seq_groups: List[SequenceGroup],
-    ) -> List[SequenceGroup]:
-        return sorted(
-            seq_groups,
-            key=lambda seq_group: self.get_priority(now, seq_group),
-            reverse=True,
-        )
+        seq_groups: Deque[SequenceGroup],
+    ) -> Deque[SequenceGroup]:
+        return deque(
+            sorted(
+                seq_groups,
+                key=lambda seq_group: self.get_priority(now, seq_group),
+                reverse=True,
+            ))


 class FCFS(Policy):
...
vllm/core/scheduler.py  (view file @ ead94d93)

+from collections import deque
 import enum
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union

 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.block_manager import AllocStatus, BlockSpaceManager
...
@@ -29,7 +30,7 @@ class SchedulerOutputs:

     def __init__(
         self,
-        scheduled_seq_groups: List[SequenceGroup],
+        scheduled_seq_groups: Iterable[SequenceGroup],
         prompt_run: bool,
         num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
...
@@ -75,38 +76,52 @@ class Scheduler:
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window)

-        # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.
-        self.waiting: List[SequenceGroup] = []
+        self.waiting: Deque[SequenceGroup] = deque()
         # Sequence groups in the RUNNING state.
-        self.running: List[SequenceGroup] = []
+        self.running: Deque[SequenceGroup] = deque()
         # Sequence groups in the SWAPPED state.
-        self.swapped: List[SequenceGroup] = []
+        self.swapped: Deque[SequenceGroup] = deque()

     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

     def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a sequence group with the given ID.
+
+        Check if the sequence group with the given ID
+            is present in any of the state queues.
+        If present, remove the sequence group from the state queue.
+        Also, if any of the sequences in the sequence group is not finished,
+            free the sequence with status `FINISHED_ABORTED`.
+        Otherwise, do nothing.
+
+        Args:
+            request_id: The ID(s) of the sequence group to abort.
+        """
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            # We need to reverse the list as we are removing elements
-            # from it as we iterate over it. If we don't do it,
-            # indices will get messed up and we will skip over elements.
-            for seq_group in reversed(state_queue):
+            aborted_groups = []
+            for seq_group in state_queue:
+                if not request_ids:
+                    # Using 'break' here may add two extra iterations,
+                    # but is acceptable to reduce complexity.
+                    break
                 if seq_group.request_id in request_ids:
-                    # Remove the sequence group from the state queue.
-                    state_queue.remove(seq_group)
-                    for seq in seq_group.get_seqs():
-                        if seq.is_finished():
-                            continue
-                        seq.status = SequenceStatus.FINISHED_ABORTED
-                        self.free_seq(seq)
+                    # Appending aborted group into pending list.
+                    aborted_groups.append(seq_group)
                     request_ids.remove(seq_group.request_id)
-                    if not request_ids:
-                        return
+            for aborted_group in aborted_groups:
+                # Remove the sequence group from the state queue.
+                state_queue.remove(aborted_group)
+                for seq in aborted_group.get_seqs():
+                    if seq.is_finished():
+                        continue
+                    seq.status = SequenceStatus.FINISHED_ABORTED
+                    self.free_seq(seq)

     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
...
@@ -152,7 +167,7 @@ class Scheduler:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                 continue

             # If the sequence group cannot be allocated, stop.
...
@@ -166,7 +181,7 @@ class Scheduler:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                 continue

             # If the number of batched tokens exceeds the limit, stop.
...
@@ -188,7 +203,7 @@ class Scheduler:
                 break
             seq_lens = new_seq_lens

-            seq_group = self.waiting.pop(0)
+            seq_group = self.waiting.popleft()
             self._allocate(seq_group)
             self.running.append(seq_group)
             num_curr_seqs += num_new_seqs
...
@@ -214,14 +229,14 @@ class Scheduler:
         self.running = self.policy.sort_by_priority(now, self.running)

         # Reserve new token slots for the running sequence groups.
-        running: List[SequenceGroup] = []
+        running: Deque[SequenceGroup] = deque()
         preempted: List[SequenceGroup] = []
         while self.running:
-            seq_group = self.running.pop(0)
+            seq_group = self.running.popleft()
             while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = self.running.pop(-1)
+                    victim_seq_group = self.running.pop()
                     self._preempt(victim_seq_group, blocks_to_swap_out)
                     preempted.append(victim_seq_group)
                 else:
...
@@ -255,7 +270,7 @@ class Scheduler:
                     self.scheduler_config.max_num_seqs):
                 break

-            seq_group = self.swapped.pop(0)
+            seq_group = self.swapped.popleft()
             self._swap_in(seq_group, blocks_to_swap_in)
             self._append_slot(seq_group, blocks_to_copy)
             num_curr_seqs += num_new_seqs
...
@@ -376,7 +391,7 @@ class Scheduler:
             self.block_manager.free(seq)
         # NOTE: For FCFS, we insert the preempted sequence group to the front
         # of the waiting queue.
-        self.waiting.insert(0, seq_group)
+        self.waiting.appendleft(seq_group)

     def _preempt_by_swap(
         self,
...
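
The policy.py and scheduler.py changes above replace list.pop(0)/insert(0) with deque.popleft()/appendleft(), addressing the removed TODO about using a deque for better performance. A minimal sketch (not part of the commit) of why this matters: popping from the front of a Python list is O(n) because every remaining element shifts, while the same operation on collections.deque is O(1):

from collections import deque

# list-based queue: each pop(0) shifts all remaining elements, O(n) per pop.
waiting_list = list(range(100_000))
while waiting_list:
    waiting_list.pop(0)

# deque-based queue: popleft() is O(1); appendleft() gives the same cheap
# "re-insert at the front" used when a preempted sequence group is requeued.
waiting_deque = deque(range(100_000))
while waiting_deque:
    waiting_deque.popleft()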
vllm/engine/async_llm_engine.py  (view file @ ead94d93)

...
@@ -253,7 +253,8 @@ class AsyncLLMEngine:
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
-        *args, *kwargs: Arguments for LLMEngine.
+        *args: Arguments for LLMEngine.
+        *kwargs: Arguments for LLMEngine.
     """

     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
...
@@ -428,6 +429,49 @@ class AsyncLLMEngine:
         Yields:
             The output `RequestOutput` objects from the LLMEngine for the
             request.
+
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+
+        Example:
+            >>> # Please refer to entrypoints/api_server.py for
+            >>> # the complete example.
+            >>>
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "prompt": "What is LLM?",
+            >>>     "stream": False, # assume the non-streaming case
+            >>>     "temperature": 0.0,
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.generate(
+            >>>    example_input["prompt"],
+            >>>    SamplingParams(temperature=example_input["temperature"]),
+            >>>    example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
         """
         # Preprocess the request.
         # This should not be used for logging, as it is monotonic time.
...
@@ -506,3 +550,9 @@ class AsyncLLMEngine:
             max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop)
         return engine
+
+    async def do_log_stats(self) -> None:
+        if self.engine_use_ray:
+            await self.engine.do_log_stats.remote()
+        else:
+            self.engine.do_log_stats()
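
The new AsyncLLMEngine.do_log_stats hook is what the periodic _force_log task added to vllm/entrypoints/openai/api_server.py (later in this commit) calls every ten seconds. A minimal sketch of driving it from an asyncio loop, assuming an already-constructed engine:

import asyncio

async def log_stats_periodically(engine, interval_s: float = 10.0) -> None:
    # Mirrors the _force_log task in the OpenAI server: periodically ask the
    # engine (local or Ray-backed) to log its system stats.
    while True:
        await asyncio.sleep(interval_s)
        await engine.do_log_stats()

# Usage (assuming engine = AsyncLLMEngine.from_engine_args(engine_args)):
# asyncio.create_task(log_stats_periodically(engine))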
vllm/engine/llm_engine.py  (view file @ ead94d93)

...
@@ -257,7 +257,26 @@ class LLMEngine:
         self.cache_config.verify_with_parallel_config(self.parallel_config)

     def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache."""
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~vllm.worker.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
...
@@ -334,6 +353,30 @@ class LLMEngine:
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+
+        Details:
+            - Set arrival_time to the current time if it is None.
+            - Set prompt_token_ids to the encoded prompt if it is None.
+            - Create `best_of` number of :class:`~vllm.Sequence` objects.
+            - Create a :class:`~vllm.SequenceGroup` object
+              from the list of :class:`~vllm.Sequence`.
+            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
+
+        Example:
+            >>> # initialize engine
+            >>> engine = LLMEngine.from_engine_args(engine_args)
+            >>> # set request arguments
+            >>> example_prompt = "Who is the president of the United States?"
+            >>> sampling_params = SamplingParams(temperature=0.0)
+            >>> request_id = 0
+            >>>
+            >>> # add the request to the engine
+            >>> engine.add_request(
+            >>>    str(request_id),
+            >>>    example_prompt,
+            >>>    SamplingParams(temperature=0.0))
+            >>> # continue the request processing
+            >>> ...
         """
         if arrival_time is None:
             arrival_time = time.monotonic()
...
@@ -358,6 +401,17 @@ class LLMEngine:
         Args:
             request_id: The ID(s) of the request to abort.
+
+        Details:
+            - Refer to the
+              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
+              from class :class:`~vllm.core.scheduler.Scheduler`.
+
+        Example:
+            >>> # initialize engine and add a request with request_id
+            >>> request_id = str(0)
+            >>> # abort the request
+            >>> engine.abort_request(request_id)
         """
         self.scheduler.abort_seq_group(request_id)
...
@@ -601,8 +655,10 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in (scheduled_seq_groups +
-                          scheduler_outputs.ignored_seq_groups):
+        for seq_group in scheduled_seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
+            request_outputs.append(request_output)
+        for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
...
@@ -615,11 +671,53 @@ class LLMEngine:
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.

-        This function performs one decoding iteration of the engine. It first
-        schedules the sequences to be executed in the next iteration and the
-        token blocks to be swapped in/out/copy. Then, it executes the model
-        and updates the scheduler with the model outputs. Finally, it decodes
-        the sequences and returns the newly generated results.
+        .. figure:: https://i.imgur.com/sv2HssD.png
+            :alt: Overview of the step function
+            :align: center
+
+            Overview of the step function.
+
+        Details:
+            - Step 1: Schedules the sequences to be executed in the next
+              iteration and the token blocks to be swapped in/out/copy.
+
+                - Depending on the scheduling policy,
+                  sequences may be `preempted/reordered`.
+                - A Sequence Group (SG) refers to a group of sequences
+                  that are generated from the same prompt.
+
+            - Step 2: Calls the workers to execute the model.
+            - Step 3: Processes the model output. This mainly includes:
+
+                - Decodes the relevant outputs.
+                - Updates the scheduled sequence groups with model outputs
+                  based on its `sampling parameters` (`use_beam_search` or not).
+                - Frees the finished sequence groups.
+
+            - Finally, it creates and returns the newly generated results.
+
+        Example:
+            >>> # Please see the example/ folder for more detailed examples.
+            >>>
+            >>> # initialize engine and request arguments
+            >>> engine = LLMEngine.from_engine_args(engine_args)
+            >>> example_inputs = [(0, "What is LLM?",
+            >>>    SamplingParams(temperature=0.0))]
+            >>>
+            >>> # Start the engine with an event loop
+            >>> while True:
+            >>>     if example_inputs:
+            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
+            >>>         engine.add_request(str(req_id), prompt, sampling_params)
+            >>>
+            >>>     # continue the request processing
+            >>>     request_outputs = engine.step()
+            >>>     for request_output in request_outputs:
+            >>>         if request_output.finished:
+            >>>             # return or show the request output
+            >>>
+            >>>     if not (engine.has_unfinished_requests() or example_inputs):
+            >>>         break
         """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
...
@@ -641,6 +739,9 @@ class LLMEngine:
         return self._process_model_outputs(output, scheduler_outputs)

+    def do_log_stats(self) -> None:
+        self._log_system_stats(False, 0)
+
     def _log_system_stats(
         self,
         prompt_run: bool,
...
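
The expanded _init_cache docstring says the engine takes the minimum block counts across all workers so the chosen KV-cache size fits the most memory-constrained worker. A hedged sketch of that reduction; the per-worker tuple layout and the numbers here are assumptions for illustration, not copied from the commit:

# Assume each worker returned a (num_gpu_blocks, num_cpu_blocks) pair from
# the "profile_num_available_blocks" call; the engine can only allocate what
# every worker can hold, so it keeps the minimum of each count.
num_blocks = [(1630, 4096), (1588, 4096)]        # hypothetical per-worker results
num_gpu_blocks = min(b[0] for b in num_blocks)   # 1588
num_cpu_blocks = min(b[1] for b in num_blocks)   # 4096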
vllm/engine/ray_utils.py  (view file @ ead94d93)

...
@@ -55,7 +55,7 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, Optional["PlacementGroup"]]:
+) -> Optional["PlacementGroup"]:
     """Initialize the distributed cluster probably with Ray.

     Args:
...
vllm/entrypoints/api_server.py  (view file @ ead94d93)

...
@@ -74,12 +74,18 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)

+    app.root_path = args.root_path
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
...
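
For reference, the new --root-path flag only matters when the API server sits behind a path-based routing proxy, e.g. something like `python -m vllm.entrypoints.api_server --root-path /llm` when clients reach the server at a `/llm` prefix (the prefix here is an assumed example, not from the commit).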
vllm/entrypoints/openai/api_server.py  (view file @ ead94d93)

...
@@ -6,6 +6,7 @@ import asyncio
 import codecs
 import json
 import time
+from contextlib import asynccontextmanager
 from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
...
@@ -38,11 +39,28 @@ TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None
-app = fastapi.FastAPI()
+engine_args = None
 engine = None
 response_role = None


+@asynccontextmanager
+async def lifespan(app: fastapi.FastAPI):
+
+    async def _force_log():
+        while True:
+            await asyncio.sleep(10)
+            await engine.do_log_stats()
+
+    if not engine_args.disable_log_stats:
+        asyncio.create_task(_force_log())
+
+    yield
+
+
+app = fastapi.FastAPI(lifespan=lifespan)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
...
@@ -88,6 +106,11 @@ def parse_args():
                         type=str,
                         default=None,
                         help="The file path to the SSL cert file")
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
     parser = AsyncEngineArgs.add_cli_args(parser)

     return parser.parse_args()
...
@@ -748,6 +771,7 @@ if __name__ == "__main__":
     # Register labels for metrics
     add_global_metrics_labels(model_name=engine_args.model)

+    app.root_path = args.root_path
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
...
vllm/model_executor/layers/attention.py  (view file @ ead94d93)

...
@@ -156,20 +156,15 @@ class PagedAttention(nn.Module):
             output = out.view_as(query)
         else:
             # Decoding run.
-            if key_cache is not None and value_cache is not None:
-                output = _paged_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    input_metadata,
-                    self.num_kv_heads,
-                    self.scale,
-                    self.alibi_slopes,
-                )
-            else:
-                # This happens during the initial memory profiling run for
-                # CUDA graphs.
-                output = torch.zeros_like(query)
+            output = _paged_attention(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )

         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
...
vllm/model_executor/layers/linear.py  (view file @ ead94d93)

...
@@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_offset = shard_offset // param.pack_factor
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            shard_id = tp_rank // self.num_kv_head_replicas
+            if loaded_shard_id == "q":
+                shard_id = tp_rank
+            else:
+                shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
...
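
The linear.py change distinguishes how query and key/value shards map onto tensor-parallel ranks: the q projection is split across every rank, while k/v heads can be replicated when there are fewer KV heads than ranks. A small sketch with assumed numbers (not from the commit):

# Hypothetical setup: 8 tensor-parallel ranks, grouped-query attention where
# each KV head is shared by 4 ranks (num_kv_head_replicas = 4).
num_kv_head_replicas = 4
for tp_rank in range(8):
    q_shard_id = tp_rank                           # every rank loads its own q slice
    kv_shard_id = tp_rank // num_kv_head_replicas  # ranks 0-3 share shard 0, ranks 4-7 share shard 1
    print(tp_rank, q_shard_id, kv_shard_id)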
vllm/model_executor/layers/rejection_sampler.py
0 → 100644
View file @
ead94d93
from
typing
import
Tuple
,
Optional
from
functools
import
cached_property
import
torch
import
torch.nn
as
nn
import
torch.jit
class
RejectionSampler
(
nn
.
Module
):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
"""
def
__init__
(
self
,
strict_mode
:
bool
=
False
):
"""Create a rejection sampler.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super
().
__init__
()
self
.
probs_dtype
=
torch
.
float32
self
.
token_id_dtype
=
torch
.
int64
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self
.
_num_bonus_tokens
=
1
self
.
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
num_emitted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
num_draft_tokens
:
int
=
0
def
init_gpu_tensors
(
self
,
rank
:
int
)
->
None
:
assert
self
.
num_accepted_tokens
is
None
device
=
f
"cuda:
{
rank
}
"
self
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
one correct token will be emitted.
In the case where all draft tokens are accepted, a bonus token will be
accepted as its cheap to have the target model score this speculative
sequence.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: The probability distribution over token ids given
context according to the draft model.
shape = [batch_size, num_speculative_tokens, vocab_size]
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_shape
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
self
.
_raise_if_incorrect_dtype
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
self
.
_raise_if_inconsistent_device
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
self
.
_raise_if_out_of_bounds_vocab
(
target_probs
.
shape
[
-
1
],
bonus_token_ids
,
draft_token_ids
)
accepted
,
recovered_token_ids
=
self
.
_batch_modified_rejection_sampling
(
target_probs
,
draft_probs
,
draft_token_ids
,
)
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
bonus_token_ids
,
)
return
output_token_ids
def
_batch_modified_rejection_sampling
(
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence is accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size
,
k
,
vocab_size
=
draft_probs
.
shape
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
)
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
recovered_token_ids
=
_multinomial
(
recovered_probs
,
num_samples
=
1
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
def
_get_accepted
(
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
:math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
same conditional probability according to the draft model, the token
is accepted with probability:
.. math::
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size
,
k
,
_
=
draft_probs
.
shape
batch_indices
=
torch
.
arange
(
batch_size
,
device
=
target_probs
.
device
)[:,
None
]
probs_indicies
=
torch
.
arange
(
k
,
device
=
target_probs
.
device
)
# shape [batch_size, k]
selected_draft_probs
=
draft_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
# shape [batch_size, k]
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
uniform_rand
=
torch
.
rand
(
batch_size
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
capped_ratio
=
torch
.
minimum
(
selected_target_probs
/
selected_draft_probs
,
torch
.
full
((
1
,
),
1
,
device
=
target_probs
.
device
))
accepted
=
uniform_rand
<
capped_ratio
return
accepted
def
_get_recovered_probs
(
self
,
target_probs
:
torch
.
Tensor
,
# [k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [k, vocab_size]
)
->
torch
.
Tensor
:
r
"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
:math:`x` given context :math:`x_1, \dots, x_n` according to the target
model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
according to the draft model:
.. math::
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
where :math:`(f(x))_+` is defined as:
.. math::
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note: This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_
,
k
,
_
=
draft_probs
.
shape
# shape [batch_size, k, vocab_size]
difference
=
target_probs
-
draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f
=
torch
.
clamp
(
difference
,
min
=
self
.
_smallest_positive_value
)
# shape [batch_size, k, vocab_size]
recovered_probs
=
f
/
torch
.
sum
(
f
,
dim
=-
1
).
reshape
(
-
1
,
k
,
1
)
return
recovered_probs
@
cached_property
def
_smallest_positive_value
(
self
)
->
float
:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
See _get_recovered_probs for more details
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            recovered_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via rejection sampling, all subsequent
        token ids are set to -1 for the sequence.

        shape = [batch_size, k + num_bonus_tokens]
        """
        bonus_token_ids = bonus_token_ids.squeeze()
        batch_size, k = recovered_token_ids.shape

        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            recovered_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens
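Side note (illustration only, mirroring the masking logic above): how a single rejection turns into -1 padding for one sequence with k=4 draft tokens.

import torch

accepted = torch.tensor([[True, True, False, True]])
k = accepted.shape[1]

limits = (accepted == 0).max(1).indices            # first rejection -> index 2
limits[~(accepted == 0).any(1)] = k                # rows with no rejection keep all k

indices = torch.arange(k).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)      # [[True, True, False, False]]
after_false_mask = indices == limits.unsqueeze(1)  # [[False, False, True, False]]
# Positions 0-1 keep their draft tokens, position 2 takes the recovered token,
# and position 3 (plus the bonus slot) stays at -1.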
    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape

        assert draft_batch_size == target_batch_size
        assert num_draft_probs == num_target_probs
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

        assert draft_token_ids_batch_size == draft_batch_size
        assert num_draft_token_ids == num_draft_probs

        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert all(probs.dtype == self.probs_dtype
                   for probs in [target_probs, draft_probs])
        assert all(token_ids.dtype == self.token_id_dtype
                   for token_ids in [bonus_token_ids, draft_token_ids])

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
        ]
        assert all([devices[0] == device for device in devices])

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
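Side note (not part of the diff): the exponential-noise trick used by _multinomial is the Gumbel-max trick in disguise. Dividing probabilities by i.i.d. Exponential(1) noise and taking the argmax draws an index with probability proportional to the probabilities, so no torch.multinomial call (and no GPU<->CPU sync) is needed. A rough empirical check:

import torch

torch.manual_seed(0)
probs = torch.tensor([[0.7, 0.2, 0.1]])
counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_(1.0)
    counts[(probs / q).argmax(dim=1)] += 1
print(counts / counts.sum())  # roughly tensor([0.7, 0.2, 0.1])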
vllm/model_executor/layers/sampler.py
View file @ ead94d93
...
@@ -76,7 +76,7 @@ class Sampler(nn.Module):
         logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
         if do_top_p_top_k:
-            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
         if do_min_p:
...
@@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits


-def _apply_top_p_top_k(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))

     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
-
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
-
-    # Final mask.
-    mask = (top_p_mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))

     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
...
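Side note (illustration only): the effect of the rewritten filter on a single row of logits with k=2 and p=0.7. The re-sort tail of the function is collapsed in the hunk above, so the scatter-back lines at the end are reconstructed under that assumption.

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k = torch.tensor([2])
p = torch.tensor([0.7])

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

# Top-k: keep only the k largest logits.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))

# Top-p: drop the low-probability tail whose total mass is <= 1 - p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False  # always keep at least one token
logits_sort.masked_fill_(top_p_mask, -float("inf"))

# Scatter back to the original order (assumed re-sort step): only the
# logit 2.0 survives in this example.
src = torch.arange(logits_idx.shape[-1]).expand_as(logits_idx)
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                       index=logits_idx,
                                                       src=src)
print(torch.gather(logits_sort, dim=-1, index=logits_idx_inv))
# tensor([[2., -inf, -inf, -inf]])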
vllm/model_executor/models/__init__.py
View file @ ead94d93
...
@@ -33,7 +33,7 @@ _MODELS = {
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "YiForCausalLM": ("yi", "YiForCausalLM"),
...
vllm/model_executor/models/phi_1_5.py → vllm/model_executor/models/phi.py
View file @ ead94d93
...
@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-class PhiEmbedding(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
-    def forward(self, input_ids: torch.LongTensor):
-        return self.wte(input_ids)
-
-
 class PhiAttention(nn.Module):

     def __init__(self,
...
@@ -93,27 +79,22 @@ class PhiAttention(nn.Module):
                          tensor_model_parallel_world_size)

         # pylint: disable=C0103
-        self.Wqkv = QKVParallelLinear(
-            self.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            linear_method=linear_method,
-        )
         self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
+            self.hidden_size,
             self.head_size,
             self.total_num_heads,
-            bias=False,
+            bias=True,
             linear_method=linear_method,
         )
-        self.out_proj = RowParallelLinear(
+        self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             linear_method=linear_method,
         )
         scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
+        rotary_dim = int(config.partial_rotary_factor *
+                         (config.hidden_size // config.num_attention_heads))
         assert rotary_dim % 2 == 0

         # pylint: disable=C0301
...
@@ -136,12 +117,12 @@ class PhiAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.Wqkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
-        output, _ = self.out_proj(attn_output)
+        output, _ = self.dense(attn_output)
         return output
...
@@ -166,8 +147,7 @@ class PhiMLP(nn.Module):
             linear_method=linear_method,
         )
         quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              n_inner)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              n_inner)

     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
...
@@ -182,9 +162,9 @@ class PhiLayer(nn.Module):
                  config: PretrainedConfig,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.mixer = PhiAttention(config, linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.self_attn = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)

     def forward(
...
@@ -195,8 +175,8 @@ class PhiLayer(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
-        hidden_states = self.ln(hidden_states)
-        attn_outputs = self.mixer(
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_outputs = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
...
@@ -215,11 +195,14 @@ class PhiModel(nn.Module):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.embd = PhiEmbedding(config)
-        self.h = nn.ModuleList([
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
             PhiLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)

     def forward(
         self,
...
@@ -228,27 +211,19 @@ class PhiModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.embd(input_ids)
+        hidden_states = self.embed_tokens(input_ids)
         for i in range(self.config.num_hidden_layers):
-            layer = self.h[i]
+            layer = self.layers[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
             )
-        return hidden_states
-
-
-class PhiCausalLMHead(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states


 class PhiForCausalLM(nn.Module):
...
@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
-        self.transformer = PhiModel(config, linear_method)
-        self.lm_head = PhiCausalLMHead(config)
+        self.model = PhiModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=True)
         self.sampler = Sampler(config.vocab_size)

     def forward(
...
@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata)
-        hidden_states = self.lm_head.ln(hidden_states)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        head = self.lm_head.linear
+        head = self.lm_head
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
         return next_tokens
...
@@ -291,17 +269,37 @@ class PhiForCausalLM(nn.Module):
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v")
+        ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            # pylint: disable=E1136
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # pylint: disable=E1136
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
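Side note (illustration only): how the stacked_params_mapping added above translates a per-projection checkpoint name onto the fused qkv_proj parameter. The checkpoint name is made up for the example.

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_name = "model.layers.0.self_attn.q_proj.weight"  # hypothetical name
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_name:
        # The q/k/v weight is loaded into its shard of the fused parameter.
        print(checkpoint_name.replace(weight_name, param_name), shard_id)
        # model.layers.0.self_attn.qkv_proj.weight q
        break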
vllm/utils.py
View file @ ead94d93
...
@@ -58,7 +58,9 @@ def in_wsl() -> bool:


 def get_ip() -> str:
-    return socket.gethostbyname(socket.gethostname())
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+    return s.getsockname()[0]


 def get_open_port() -> int:
...
vllm/worker/model_runner.py
View file @ ead94d93
...
@@ -235,9 +235,11 @@ class ModelRunner:
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
+            max_block_table_len = (max_context_len + self.block_size -
+                                   1) // self.block_size
             block_tables = _make_tensor_with_pad(
                 block_tables,
-                max_len=max_context_len,
+                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device="cuda",
...
@@ -504,7 +506,9 @@ class ModelRunner:
                         "use '--enforce-eager' in the CLI.")
         logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                     "If you are running out of memory, consider decreasing "
-                    "`gpu_memory_utilization` or enforcing eager mode.")
+                    "`gpu_memory_utilization` or enforcing eager mode. "
+                    "You can also reduce the `max_num_seqs` as needed "
+                    "to decrease memory usage.")
         start_time = time.perf_counter()

         # Prepare dummy inputs. These will be reused for all batch sizes.
...
@@ -517,9 +521,15 @@ class ModelRunner:
         context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()

+        graph_batch_size = _get_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
+        batch_size_capture_list = [
+            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
+        ]
+
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
-        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
+        for batch_size in reversed(batch_size_capture_list):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
                 is_prompt=False,
...
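Side note (illustrative assumptions, not from the diff): how batch_size_capture_list trims the set of CUDA-graph batch sizes. The list contents and the padding rule inside _get_graph_batch_size are assumed here purely for the example.

# Assumed capture sizes and padding rule, for illustration only.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8, 16, 32, 64, 128, 256]

def _get_graph_batch_size(batch_size: int) -> int:
    # Assumed rule: round up to the next captured size.
    return next(bs for bs in _BATCH_SIZES_TO_CAPTURE if bs >= batch_size)

max_num_seqs = 48  # assumed scheduler setting
graph_batch_size = _get_graph_batch_size(max_num_seqs)  # 64
batch_size_capture_list = [
    bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
print(batch_size_capture_list)  # [1, 2, 4, 8, 16, 32, 64]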
vllm/worker/worker.py
View file @ ead94d93
...
@@ -87,6 +87,14 @@ class Worker:
         gpu_memory_utilization: float,
         cpu_swap_space: int,
     ) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model and returns the maximum
+        number of GPU and CPU cache blocks that can be allocated.
+
+        Args:
+            block_size: The size of the cache block.
+            gpu_memory_utilization: The fraction of the total GPU memory to use.
+            cpu_swap_space: The size of the CPU swap space in bytes.
+        """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
...
@@ -231,4 +239,6 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
         raise ValueError(
             "Bfloat16 is only supported on GPUs with compute capability "
             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-            f"{compute_capability[0]}.{compute_capability[1]}.")
+            f"{compute_capability[0]}.{compute_capability[1]}. "
+            "You can use float16 instead by explicitly setting the"
+            "`dtype` flag in CLI, for example: --dtype=half.")