Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
705f6a35
Commit
705f6a35
authored
Jul 16, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1
parents
af837396
4cf256ae
Changes
439
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1197 additions
and
252 deletions
+1197
-252
tests/spec_decode/test_metrics.py
tests/spec_decode/test_metrics.py
+47
-47
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+222
-12
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+6
-3
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+244
-78
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+24
-2
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+15
-8
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+104
-30
tests/test_cache_block_hashing.py
tests/test_cache_block_hashing.py
+1
-1
tests/test_embedded_commit.py
tests/test_embedded_commit.py
+7
-0
tests/test_logger.py
tests/test_logger.py
+1
-0
tests/test_logits_processor.py
tests/test_logits_processor.py
+1
-1
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+1
-1
tests/test_utils.py
tests/test_utils.py
+60
-1
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+89
-24
tests/tokenization/test_get_eos.py
tests/tokenization/test_get_eos.py
+31
-0
tests/tokenization/test_image_processor.py
tests/tokenization/test_image_processor.py
+0
-20
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+99
-0
tests/tracing/__init__.py
tests/tracing/__init__.py
+0
-0
tests/tracing/test_tracing.py
tests/tracing/test_tracing.py
+116
-0
tests/utils.py
tests/utils.py
+129
-24
No files found.
Too many changes to show.
To preserve performance only
439 of 439+
files are displayed.
Plain diff
Email patch
tests/spec_decode/test_metrics.py
View file @
705f6a35
...
@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
...
@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
def
test_initial_call_returns_none
():
def
test_initial_call_returns_none
():
"""Expect first call to get metrics to return None.
"""Expect first call to get metrics to return None.
"""
"""
rej
_sampler
=
MagicMock
()
spec_decode
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
spec_decode
_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
rej
_sampler
)
collector
=
AsyncMetricsCollector
(
spec_decode
_sampler
)
collector
.
init_gpu_tensors
(
rank
=
0
)
collector
.
init_gpu_tensors
(
rank
=
0
)
maybe_metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
maybe_metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
maybe_metrics
is
None
assert
maybe_metrics
is
None
...
@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
...
@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
def
test_second_call_returns_metrics
():
def
test_second_call_returns_metrics
():
"""Expect second call to not return None.
"""Expect second call to not return None.
"""
"""
rej
_sampler
=
MagicMock
()
spec_decode
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
spec_decode
_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
=
MagicMock
()
...
@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
...
@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
]
]
collector
=
AsyncMetricsCollector
(
rejection_sampler
=
rej
_sampler
,
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode
_sampler
,
timer
=
timer
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
collector
.
init_gpu_tensors
(
rank
=
0
)
...
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
...
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
def
test_nonzero_rank_noop
(
rank
):
def
test_nonzero_rank_noop
(
rank
):
"""Verify nonzero ranks don't collect metrics.
"""Verify nonzero ranks don't collect metrics.
"""
"""
rej
_sampler
=
MagicMock
()
spec_decode
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
spec_decode
_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
rej
_sampler
)
collector
=
AsyncMetricsCollector
(
spec_decode
_sampler
)
collector
.
init_gpu_tensors
(
rank
=
rank
)
collector
.
init_gpu_tensors
(
rank
=
rank
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
...
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
...
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
def
test_noop_until_time
():
def
test_noop_until_time
():
"""Verify metrics aren't collected until enough time passes.
"""Verify metrics aren't collected until enough time passes.
"""
"""
rej
_sampler
=
MagicMock
()
spec_decode
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
0
spec_decode
_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
=
MagicMock
()
...
@@ -91,7 +91,7 @@ def test_noop_until_time():
...
@@ -91,7 +91,7 @@ def test_noop_until_time():
collect_interval_s
+
0.1
,
collect_interval_s
+
0.1
collect_interval_s
+
0.1
,
collect_interval_s
+
0.1
]
]
collector
=
AsyncMetricsCollector
(
rejection_sampler
=
rej
_sampler
,
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode
_sampler
,
timer
=
timer
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
collector
.
init_gpu_tensors
(
rank
=
0
)
...
@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
...
@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
max_num_emitted_tokens
=
AsyncMetricsCollector
.
get_max_num_emitted_tokens
(
max_num_emitted_tokens
=
AsyncMetricsCollector
.
get_max_num_emitted_tokens
(
num_draft_tokens
,
k
)
num_draft_tokens
,
k
)
rej
_sampler
=
MagicMock
()
spec_decode
_sampler
=
MagicMock
()
rej
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
num_accepted_tokens
,
spec_decode
_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
num_accepted_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
num_emitted_tokens
,
spec_decode
_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
num_emitted_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
device
=
'cuda'
)
rej
_sampler
.
num_draft_tokens
=
num_draft_tokens
spec_decode
_sampler
.
num_draft_tokens
=
num_draft_tokens
collect_interval_s
=
5.0
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
=
MagicMock
()
...
@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
...
@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
]
]
collector
=
AsyncMetricsCollector
(
rejection_sampler
=
rej
_sampler
,
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode
_sampler
,
timer
=
timer
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
collector
.
init_gpu_tensors
(
rank
=
0
)
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
705f6a35
import
random
import
random
from
typing
import
Dict
,
List
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
Logprob
,
SamplerOutput
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
...
@@ -84,6 +86,7 @@ def test_same_output_for_single_step():
...
@@ -84,6 +86,7 @@ def test_same_output_for_single_step():
block_size
,
block_size
,
num_gpu_blocks
,
num_gpu_blocks
,
seed
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
)
worker
=
create_worker
(
worker
=
create_worker
(
Worker
,
Worker
,
...
@@ -115,7 +118,8 @@ def test_same_output_for_single_step():
...
@@ -115,7 +118,8 @@ def test_same_output_for_single_step():
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
multi_step_seq_group
),
seq_group_metadata_list
=
multi_step_seq_group
),
sample_len
=
num_steps
)
sample_len
=
num_steps
,
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
len
(
actual_output
)
==
num_steps
assert
len
(
actual_output
)
==
num_steps
actual_output
=
actual_output
[
0
]
actual_output
=
actual_output
[
0
]
...
@@ -167,6 +171,7 @@ def test_same_output_for_multi_step():
...
@@ -167,6 +171,7 @@ def test_same_output_for_multi_step():
block_size
,
block_size
,
num_gpu_blocks
,
num_gpu_blocks
,
seed
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
)
worker
=
create_worker
(
worker
=
create_worker
(
...
@@ -206,11 +211,12 @@ def test_same_output_for_multi_step():
...
@@ -206,11 +211,12 @@ def test_same_output_for_multi_step():
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
num_steps
)
sample_len
=
num_steps
,
seq_ids_with_bonus_token_in_last_step
=
set
())
# Run single-step repeatedly.
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
=
[]
single_step_output
:
List
[
SamplerOutput
]
=
[]
continuations
=
[[
1
]
for
_
in
prompts
]
continuations
=
[[
1
]
for
_
in
prompts
]
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -232,11 +238,15 @@ def test_same_output_for_multi_step():
...
@@ -232,11 +238,15 @@ def test_same_output_for_multi_step():
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Get token ids and logprobs for comparison.
# Get token ids and logprobs for comparison.
multi_step_output_logprobs
=
[[]
for
_
in
prompts
]
multi_step_output_logprobs
:
List
[
List
[
Dict
[
int
,
single_step_output_logprobs
=
[[]
for
_
in
prompts
]
Logprob
]]]
=
[[]
for
_
in
prompts
]
multi_step_output_token_ids
=
[[]
for
_
in
prompts
]
single_step_output_logprobs
:
List
[
List
[
Dict
[
int
,
single_step_output_token_ids
=
[[]
for
_
in
prompts
]
Logprob
]]]
=
[[]
for
_
in
prompts
]
multi_step_output_token_ids
:
List
[
List
[
int
]]
=
[[]
for
_
in
prompts
]
single_step_output_token_ids
:
List
[
List
[
int
]]
=
[[]
for
_
in
prompts
]
for
i
,
_
in
enumerate
(
prompts
):
for
i
,
_
in
enumerate
(
prompts
):
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
single_step_output
):
single_step_output
):
...
@@ -269,6 +279,203 @@ def test_same_output_for_multi_step():
...
@@ -269,6 +279,203 @@ def test_same_output_for_multi_step():
single_step_logprobs
)
single_step_logprobs
)
@
torch
.
inference_mode
()
def
test_multi_step_with_batch_expansion_correct_output
():
"""
In this test we verify that the MultiStepWorker is able to handle bonus
tokens correctly. The test verifies that if a sequence has a
bonus token then the MultiStepWorker is able to expand the batch by adding
new sequences corresponding to the sequences with bonus tokens. The
expanded batch is then used for predicting the next tokens.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
batch_size
=
128
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
random
.
seed
(
seed
)
prompts
=
[[
0
]
for
_
in
range
(
batch_size
)]
num_steps
=
2
final_prompt_lens
=
[(
num_steps
+
1
)
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
multi_step_worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
# Create the test continuations
continuations
=
[[
random
.
randint
(
0
,
1000
)]
for
_
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
List
[
SamplerOutput
]
=
[]
set_random_seed
(
seed
)
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
.
extend
(
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations
=
[]
for
continuation
in
continuations
:
multi_step_continuations
.
append
(
continuation
[:
2
])
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
multi_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
all_seq_ids
=
{
i
for
i
in
range
(
batch_size
)}
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
1
,
seq_ids_with_bonus_token_in_last_step
=
all_seq_ids
)
for
index
,
output
in
enumerate
(
multi_step_output
[
-
1
].
outputs
):
assert
(
continuations
[
index
][
-
1
]
==
output
.
samples
[
0
].
output_token
)
@
torch
.
inference_mode
()
def
test_multi_step_with_batch_expansion_incorrect_output
():
"""
Tests the MultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
batch_size
=
128
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
random
.
seed
(
seed
)
prompts
=
[[
0
]
for
_
in
range
(
batch_size
)]
num_steps
=
2
final_prompt_lens
=
[(
num_steps
+
1
)
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
multi_step_worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
# Create the test continuations
continuations
=
[[
random
.
randint
(
0
,
1000
)]
for
_
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
List
[
SamplerOutput
]
=
[]
set_random_seed
(
seed
)
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
.
extend
(
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations
=
[]
for
continuation
in
continuations
:
multi_step_continuations
.
append
(
continuation
[:
2
])
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
multi_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step. In this run INCORRECTLY specify that only the odd number
# sequences have bonus tokens. Verify that with this setting the third token
# prediction is accurate only for the odd numbered sequences. Also verify
# that the prediction might be wrong for some of the even numbered
# sequences.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
odd_seq_ids
=
{
i
for
i
in
range
(
batch_size
)
if
i
%
2
!=
0
}
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
1
,
seq_ids_with_bonus_token_in_last_step
=
odd_seq_ids
)
num_mismatch
=
0
for
index
,
output
in
enumerate
(
multi_step_output
[
-
1
].
outputs
):
if
(
index
%
2
)
!=
0
:
assert
(
continuations
[
index
][
-
1
]
==
output
.
samples
[
0
].
output_token
)
elif
(
continuations
[
index
][
-
1
]
!=
output
.
samples
[
0
].
output_token
):
num_mismatch
+=
1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert
(
num_mismatch
>
0
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_draft_proposals_full_speculation_len
():
def
test_draft_proposals_full_speculation_len
():
"""Verify Top1Proposer correctly handles case where all sequences
"""Verify Top1Proposer correctly handles case where all sequences
...
@@ -310,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
...
@@ -310,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -348,7 +556,8 @@ def test_draft_proposals_no_speculations():
...
@@ -348,7 +556,8 @@ def test_draft_proposals_no_speculations():
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -420,7 +629,8 @@ def test_draft_proposals_mixed_k():
...
@@ -420,7 +629,8 @@ def test_draft_proposals_mixed_k():
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_ngram_worker.py
View file @
705f6a35
...
@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
...
@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
num_lookahead_slots
=
proposal_len
),
seq_ids_with_bonus_token_in_last_step
=
None
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
...
@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
num_lookahead_slots
=
proposal_len
),
seq_ids_with_bonus_token_in_last_step
=
None
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
...
@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
num_lookahead_slots
=
proposal_len
),
seq_ids_with_bonus_token_in_last_step
=
None
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
705f6a35
import
random
import
random
from
collections
import
defaultdict
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Set
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
,
SequenceOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
SpecDecodeWorkerMetrics
)
...
@@ -15,23 +16,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
...
@@ -15,23 +16,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
split_num_cache_blocks_evenly
)
split_num_cache_blocks_evenly
)
from
.test_utils
import
mock_spec_decode_sampler
from
.utils
import
create_batch
,
create_sampler_output_list
,
mock_worker
from
.utils
import
create_batch
,
create_sampler_output_list
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correctly_calls_draft_model
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_calls_draft_model
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the draft worker with correct
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
inputs. Everything else is mocked out.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
metrics_collector
)
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
exception_secret
=
'artificial stop'
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
...
@@ -52,15 +56,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
...
@@ -52,15 +56,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correctly_calls_target_model
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_calls_target_model
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the target model with correct
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
inputs. Everything else is mocked out.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
draft_worker
.
device
=
'cuda'
...
@@ -68,8 +73,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -68,8 +73,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
metrics_collector
)
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
vocab_size
=
32_000
vocab_size
=
32_000
...
@@ -103,7 +109,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -103,7 +109,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
num_lookahead_slots
=
k
))
seen_contexts
=
[]
seen_contexts
:
List
[
List
[
int
]]
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
...
@@ -116,7 +122,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -116,7 +122,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
expected_seen_contexts
=
[]
expected_seen_contexts
:
List
[
List
[
int
]]
=
[]
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
...
@@ -132,8 +138,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -132,8 +138,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correctly_calls_rejection_sampler
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_calls_spec_decode_sampler
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the rejection sampler with
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
correct inputs. Everything else is mocked out.
"""
"""
...
@@ -143,15 +152,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -143,15 +152,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
...
@@ -198,15 +206,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -198,15 +206,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
exception_secret
=
'artificial stop'
exception_secret
=
'artificial stop'
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
spec_decode_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
num_lookahead_slots
=
k
))
assert
len
(
rejection
_sampler
.
call_args_list
)
==
1
assert
len
(
spec_decode
_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
rejection
_sampler
.
call_args_list
[
0
]
_
,
kwargs
=
spec_decode
_sampler
.
call_args_list
[
0
]
actual
=
SimpleNamespace
(
**
kwargs
)
actual
=
SimpleNamespace
(
**
kwargs
)
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
...
@@ -220,8 +229,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -220,8 +229,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correctly_formats_output
(
k
:
int
,
batch_size
:
int
):
def
test_correctly_formats_output
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker formats sampler output correctly.
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
Everything else is mocked out.
"""
"""
...
@@ -231,15 +243,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -231,15 +243,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
set_random_seed
(
1
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
...
@@ -285,24 +295,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -285,24 +295,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
rejection
_sampler_output
=
torch
.
randint
(
low
=
0
,
spec_decode
_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
device
=
'cuda'
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
minimum_accepted_tokens
=
1
minimum_accepted_tokens
=
1
rejection
_sampler_output
[
i
][
spec_decode
_sampler_output
[
i
][
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
rejection_sampler
.
return_value
=
rejection_sampler_output
spec_decode_sampler
.
return_value
=
spec_decode_sampler_output
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
num_lookahead_slots
=
k
))
expected_output
=
create_sampler_output_list
(
expected_output
=
create_sampler_output_list
(
token_ids
=
rejection
_sampler_output
.
transpose
(
0
,
1
),
token_ids
=
spec_decode
_sampler_output
.
transpose
(
0
,
1
),
probs
=
[
None
for
_
in
range
(
k
+
1
)],
probs
=
[
None
for
_
in
range
(
k
+
1
)],
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
...
@@ -310,8 +319,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -310,8 +319,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
seq_group_metadata_list
for
seq_group_metadata
in
seq_group_metadata_list
]
]
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
actual_output_by_seq
:
Dict
[
int
,
List
[
SequenceOutput
]]
=
{
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
:
Dict
[
int
,
List
[
SequenceOutput
]]
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
for
step
in
output
:
for
step
in
output
:
for
seq_group
in
step
:
for
seq_group
in
step
:
...
@@ -343,8 +358,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -343,8 +358,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'returns_metrics'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
'returns_metrics'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_collects_metrics
(
k
:
int
,
batch_size
:
int
,
returns_metrics
:
bool
):
def
test_collects_metrics
(
k
:
int
,
batch_size
:
int
,
returns_metrics
:
bool
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker collects metrics.
"""Verify SpecDecodeWorker collects metrics.
"""
"""
vocab_size
=
32_000
vocab_size
=
32_000
...
@@ -353,16 +371,17 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -353,16 +371,17 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
metrics_collector
)
target_worker
,
spec_decode_sampler
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
...
@@ -407,17 +426,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -407,17 +426,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
rejection
_sampler_output
=
torch
.
randint
(
low
=
0
,
spec_decode
_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
device
=
'cuda'
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
minimum_accepted_tokens
=
1
minimum_accepted_tokens
=
1
rejection
_sampler_output
[
i
][
spec_decode
_sampler_output
[
i
][
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
spec_decode_sampler
.
return_value
=
spec_decode_sampler_output
rejection_sampler
.
return_value
=
rejection_sampler_output
mock_rejsample_metrics
=
MagicMock
(
mock_rejsample_metrics
=
MagicMock
(
spec
=
SpecDecodeWorkerMetrics
)
if
returns_metrics
else
None
spec
=
SpecDecodeWorkerMetrics
)
if
returns_metrics
else
None
...
@@ -438,26 +456,30 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -438,26 +456,30 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_k_equals_zero
(
k
:
int
,
batch_size
:
int
):
def
test_k_equals_zero
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that the SpecDecodeWorker calls the draft and target workers
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
when k is zero. This happens during prefill.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
execute_model
.
return_value
=
[
MagicMock
(
spec
=
SamplerOutput
)]
sampler_output
=
MagicMock
(
spec
=
SamplerOutput
)
sampler_output
.
hidden_states
=
None
target_worker
.
execute_model
.
return_value
=
[
sampler_output
]
draft_worker
.
device
=
'cuda'
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
metrics_collector
)
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
...
@@ -478,27 +500,31 @@ def test_k_equals_zero(k: int, batch_size: int):
...
@@ -478,27 +500,31 @@ def test_k_equals_zero(k: int, batch_size: int):
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_empty_input_batch
(
k
:
int
,
batch_size
:
int
):
def
test_empty_input_batch
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that the SpecDecodeWorker calls the draft and target workers
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
to the workers information without scheduling a batch.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
execute_model
.
return_value
=
[
MagicMock
(
spec
=
SamplerOutput
)]
sampler_output
=
MagicMock
(
spec
=
SamplerOutput
)
sampler_output
.
hidden_states
=
None
target_worker
.
execute_model
.
return_value
=
[
sampler_output
]
draft_worker
.
device
=
'cuda'
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
metrics_collector
)
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
...
@@ -517,20 +543,20 @@ def test_empty_input_batch(k: int, batch_size: int):
...
@@ -517,20 +543,20 @@ def test_empty_input_batch(k: int, batch_size: int):
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
def
test_init_device
():
def
test_init_device
(
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
well as other GPU initialization.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection
_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode
_sampler
,
metrics_collector
)
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
draft_worker
.
init_device
.
assert_called_once
()
draft_worker
.
init_device
.
assert_called_once
()
...
@@ -538,22 +564,23 @@ def test_init_device():
...
@@ -538,22 +564,23 @@ def test_init_device():
target_worker
.
init_device
.
assert_called_once
()
target_worker
.
init_device
.
assert_called_once
()
metrics_collector
.
init_gpu_tensors
.
assert_called_once
()
metrics_collector
.
init_gpu_tensors
.
assert_called_once
()
rejection
_sampler
.
init_gpu_tensors
.
assert_called_once
()
spec_decode
_sampler
.
init_gpu_tensors
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_initialize_cache
():
def
test_initialize_cache
(
acceptance_sampler_method
):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
workers.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
metrics_collector
)
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
worker
.
initialize_cache
(
**
kwargs
)
worker
.
initialize_cache
(
**
kwargs
)
...
@@ -566,19 +593,20 @@ def test_initialize_cache():
...
@@ -566,19 +593,20 @@ def test_initialize_cache():
@
pytest
.
mark
.
parametrize
(
'available_cpu_blocks'
,
[
500
])
@
pytest
.
mark
.
parametrize
(
'available_cpu_blocks'
,
[
500
])
@
pytest
.
mark
.
parametrize
(
'target_cache_block_size_bytes'
,
[
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
'target_cache_block_size_bytes'
,
[
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
'draft_kv_size_bytes'
,
[
0
,
2
*
2
*
768
,
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
'draft_kv_size_bytes'
,
[
0
,
2
*
2
*
768
,
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
def
test_determine_num_available_blocks
(
available_gpu_blocks
:
int
,
def
test_determine_num_available_blocks
(
available_gpu_blocks
:
int
,
available_cpu_blocks
:
int
,
available_cpu_blocks
:
int
,
target_cache_block_size_bytes
:
int
,
target_cache_block_size_bytes
:
int
,
draft_kv_size_bytes
:
int
):
draft_kv_size_bytes
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
split the blocks between proposer and scorer worker.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
determine_num_available_blocks
.
return_value
=
(
target_worker
.
determine_num_available_blocks
.
return_value
=
(
...
@@ -587,8 +615,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
...
@@ -587,8 +615,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
target_cache_block_size_bytes
)
target_cache_block_size_bytes
)
draft_worker
.
get_cache_block_size_bytes
.
return_value
=
draft_kv_size_bytes
draft_worker
.
get_cache_block_size_bytes
.
return_value
=
draft_kv_size_bytes
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
metrics_collector
)
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
num_gpu_blocks
,
num_cpu_blocks
=
worker
.
determine_num_available_blocks
()
num_gpu_blocks
,
num_cpu_blocks
=
worker
.
determine_num_available_blocks
()
...
@@ -618,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
...
@@ -618,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
assert
(
num_blocks
*
target_cache_block_size_bytes
)
+
(
assert
(
num_blocks
*
target_cache_block_size_bytes
)
+
(
num_blocks
*
draft_kv_size_bytes
)
<=
(
available_gpu_blocks
*
num_blocks
*
draft_kv_size_bytes
)
<=
(
available_gpu_blocks
*
target_cache_block_size_bytes
)
target_cache_block_size_bytes
)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Check that _create_output_sampler_list refreshes the worker's
    bookkeeping of sequences that received a bonus token.

    SpecDecodeWorker keeps _seq_with_bonus_token_in_last_step, the set of
    sequence IDs that were given a bonus token by the target model in their
    most recent forward pass. That state is only maintained for proposers
    that rely on the KV cache, such as the MultiStepWorker used here. The
    test also checks the request-id -> sequence-ids mapping that the same
    call populates.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    # Number of stale sequence IDs (outside the batch) pre-loaded into the
    # worker's bonus-token set further below.
    num_extra_sequence_ids = 10

    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'

    # Sequence IDs for the batch: the sequence at position i is given
    # seq_id assigned_seq_ids[i].
    assigned_seq_ids = list(range(batch_size))
    batch = create_batch(batch_size,
                         k,
                         seq_ids=assigned_seq_ids,
                         prev_output_token_len=10)
    seq_group_metadata_list, _, _ = batch

    target_token_logprobs = torch.rand(batch_size,
                                       k + 1,
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k + 1),
                                       dtype=torch.int64,
                                       device='cuda')

    # Expected request-id -> {sequence ids} mapping, read off the batch.
    expected_mapping: Dict[str, Set[int]] = defaultdict(set)
    for metadata in seq_group_metadata_list:
        for seq_id in metadata.seq_data:
            expected_mapping[metadata.request_id].add(seq_id)

    # Randomly choose which batch positions keep their bonus token.
    bonus_indexes = random.sample(range(batch_size),
                                  num_sequences_with_bonus_tokens)

    # Boolean mask that is True for every position that does NOT keep a
    # bonus token (it starts all-True and the chosen positions are cleared).
    no_bonus_mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    no_bonus_mask[bonus_indexes] = False

    # A last token ID of -1 marks the absence of a bonus token, so write it
    # into the final column of every masked row.
    accepted_token_ids[no_bonus_mask, -1:] = -1

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector=metrics_collector)

    # Seed the worker's set with every sequence ID in the batch plus
    # `num_extra_sequence_ids` additional ones, i.e. all IDs in
    # [0, batch_size + num_extra_sequence_ids).
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))

    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k)

    # After the call the set must contain:
    # 1. Every pre-existing sequence ID that was not part of this batch.
    # 2. Of the sequence IDs in this batch, only those that got a bonus
    #    token; batch members without one must have been dropped.
    batch_ids_with_bonus = {assigned_seq_ids[i] for i in bonus_indexes}
    ids_outside_batch = set(
        range(batch_size, batch_size + num_extra_sequence_ids))
    expected_bonus_ids = batch_ids_with_bonus | ids_outside_batch

    assert worker._seq_with_bonus_token_in_last_step == expected_bonus_ids
    assert worker._request_id_seq_id_mapping == expected_mapping
@torch.inference_mode()
def test_handle_finished_requests():
    """
    Check that finished request IDs passed in an ExecuteModelRequest are
    used to prune the SpecDecodeWorker's internal state.

    The worker is seeded with a fake request-id -> sequence-ids mapping and
    a fake set of sequences holding bonus tokens. Two requests are then
    reported as finished, and the test asserts that their entries (and
    their sequence IDs) are gone afterwards.
    """
    batch_size = 32
    k = 3
    finished_request_ids = ['request-1', 'request-3']

    # Fake request ids with their sequence ids, and a fake set of sequence
    # ids that received a bonus token in the last step.
    seeded_mapping = {
        'request-1': {1, 2, 3},
        'request-2': {4, 5, 6, 7},
        'request-3': {8, 9},
        'request-4': {10, 11},
    }
    seeded_bonus_ids = {1, 4, 5, 8, 9, 10}

    # Work out the expected post-state up front, as fresh objects, so that
    # nothing the worker mutates can leak into the expectation.
    finished_seq_ids = set()
    for request_id in finished_request_ids:
        finished_seq_ids |= seeded_mapping[request_id]
    expected_mapping = {
        request_id: set(seq_ids)
        for request_id, seq_ids in seeded_mapping.items()
        if request_id not in finished_request_ids
    }
    expected_bonus_ids = seeded_bonus_ids - finished_seq_ids

    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector)

    worker._request_id_seq_id_mapping = seeded_mapping
    worker._seq_with_bonus_token_in_last_step = seeded_bonus_ids

    # Make the proposer raise so execute_model stops at the proposal step;
    # the state assertions below run after that exception is caught.
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    # Report request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=finished_request_ids)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    # request-1 and request-3 must be gone from the mapping, leaving
    # {'request-2': {4, 5, 6, 7}, 'request-4': {10, 11}}.
    assert worker._request_id_seq_id_mapping == expected_mapping
    # Every sequence id of the finished requests must be gone from the
    # bonus-token set, leaving {4, 5, 10}.
    assert worker._seq_with_bonus_token_in_last_step == expected_bonus_ids
tests/spec_decode/test_utils.py
View file @
705f6a35
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
import
torch
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.spec_decode.util
import
get_all_seq_ids
,
split_batch_by_proposal_len
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.sequence
import
SequenceGroupMetadata
,
get_all_seq_ids
from
vllm.spec_decode.util
import
split_batch_by_proposal_len
def
test_get_all_seq_ids
():
def
test_get_all_seq_ids
():
...
@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
...
@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
assert
filtered_groups
==
[]
assert
filtered_groups
==
[]
assert
indices
==
[]
assert
indices
==
[]
def mock_spec_decode_sampler(acceptance_sampler_method):
    """
    Build a MagicMock standing in for a speculative-decoding sampler.

    The mock is spec'd against RejectionSampler when
    acceptance_sampler_method is 'rejection_sampler', and against
    TypicalAcceptanceSampler when it is 'typical_acceptance_sampler'.
    In both cases its token_id_dtype attribute is set to torch.int64.

    Raises:
        ValueError: if acceptance_sampler_method is any other value.
    """
    # Pick the class to mirror; the names are only looked up in the branch
    # that is actually taken.
    if acceptance_sampler_method == "rejection_sampler":
        sampler_cls = RejectionSampler
    elif acceptance_sampler_method == "typical_acceptance_sampler":
        sampler_cls = TypicalAcceptanceSampler
    else:
        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")

    mocked_sampler = MagicMock(spec=sampler_cls)
    mocked_sampler.token_id_dtype = torch.int64
    return mocked_sampler
tests/spec_decode/utils.py
View file @
705f6a35
from
itertools
import
count
from
itertools
import
count
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypeVar
,
Union
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
torch
import
torch
...
@@ -12,8 +14,11 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
...
@@ -12,8 +14,11 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
T
=
TypeVar
(
"T"
,
bound
=
Worker
)
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
return
(
seq_len
+
block_size
-
1
)
//
block_size
return
(
seq_len
+
block_size
-
1
)
//
block_size
...
@@ -49,20 +54,21 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
...
@@ -49,20 +54,21 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return
new_execute_model
return
new_execute_model
def
zero_kv_cache
(
cache_engine
:
CacheEngine
):
def
zero_kv_cache
(
cache_engine
:
List
[
CacheEngine
]
):
assert
cache_engine
.
gpu_cache
assert
cache_engine
[
0
]
.
gpu_cache
for
key_blocks
,
value_blocks
in
cache_engine
.
gpu_cache
:
for
key_blocks
,
value_blocks
in
cache_engine
[
0
]
.
gpu_cache
:
key_blocks
.
zero_
()
key_blocks
.
zero_
()
value_blocks
.
zero_
()
value_blocks
.
zero_
()
def
create_worker
(
cls
:
type
,
def
create_worker
(
cls
:
Callable
[...,
T
]
,
model_name
:
str
,
model_name
:
str
,
block_size
:
int
,
block_size
:
int
,
num_gpu_blocks
:
int
,
num_gpu_blocks
:
int
,
seed
:
int
,
seed
:
int
,
is_driver_worker
:
bool
=
True
,
is_driver_worker
:
bool
=
True
,
enforce_eager
:
bool
=
True
):
enforce_eager
:
bool
=
True
,
model_runner_cls
:
Optional
[
ModelRunner
]
=
None
)
->
T
:
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model_name
,
model
=
model_name
,
seed
=
seed
,
seed
=
seed
,
...
@@ -85,6 +91,7 @@ def create_worker(cls: type,
...
@@ -85,6 +91,7 @@ def create_worker(cls: type,
rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
is_driver_worker
,
is_driver_worker
=
is_driver_worker
,
model_runner_cls
=
model_runner_cls
,
)
)
worker
.
init_device
()
worker
.
init_device
()
...
@@ -159,8 +166,8 @@ def assert_logprobs_dict_allclose(
...
@@ -159,8 +166,8 @@ def assert_logprobs_dict_allclose(
def
create_sampler_output_list
(
def
create_sampler_output_list
(
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
probs
:
Iterabl
e
[
Optional
[
torch
.
Tensor
]],
probs
:
GenericSequenc
e
[
Optional
[
torch
.
Tensor
]],
logprobs
:
Iterabl
e
[
Optional
[
torch
.
Tensor
]],
logprobs
:
GenericSequenc
e
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
token_ids_by_step
=
token_ids
.
tolist
()
...
...
tests/tensorizer_loader/test_tensorizer.py
View file @
705f6a35
import
json
import
json
import
os
import
os
import
pathlib
import
subprocess
import
subprocess
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
,
patch
import
openai
import
openai
import
pytest
import
pytest
import
ray
import
torch
from
tensorizer
import
EncryptionParams
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
# yapf: disable
# yapf: disable
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
TensorSerializer
,
TensorSerializer
,
is_vllm_tensorized
,
is_vllm_tensorized
,
load_with_tensorizer
,
load_with_tensorizer
,
open_stream
,
open_stream
,
serialize_vllm_model
)
serialize_vllm_model
,
tensorize_vllm_model
)
from
..utils
import
ServerRunner
from
..conftest
import
VllmRunner
,
cleanup
from
..utils
import
RemoteOpenAIServer
# yapf conflicts with isort for this docstring
# yapf conflicts with isort for this docstring
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
@@ -42,6 +48,20 @@ def is_curl_installed():
...
@@ -42,6 +48,20 @@ def is_curl_installed():
except
(
subprocess
.
CalledProcessError
,
FileNotFoundError
):
except
(
subprocess
.
CalledProcessError
,
FileNotFoundError
):
return
False
return
False
def
get_torch_model
(
vllm_runner
:
VllmRunner
):
return
vllm_runner
\
.
model
\
.
llm_engine
\
.
model_executor
\
.
driver_worker
\
.
model_runner
\
.
model
def
write_keyfile
(
keyfile_path
:
str
):
encryption_params
=
EncryptionParams
.
random
()
pathlib
.
Path
(
keyfile_path
).
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
with
open
(
keyfile_path
,
'wb'
)
as
f
:
f
.
write
(
encryption_params
.
key
)
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
tensorizer_config
():
def
tensorizer_config
():
...
@@ -88,12 +108,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
...
@@ -88,12 +108,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
with
vllm_runner
(
model_ref
)
as
vllm_model
:
with
vllm_runner
(
model_ref
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
write_keyfile
(
key_path
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
config_for_serializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
)
config_for_serializing
=
TensorizerConfig
(
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
tensorizer_uri
=
model_path
,
config_for_serializing
,
encryption_keyfile
=
key_path
encryption_key_path
=
key_path
)
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
config_for_serializing
)
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
)
encryption_keyfile
=
key_path
)
...
@@ -145,7 +170,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
...
@@ -145,7 +170,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_
engine
,
serialize_vllm_model
(
get_torch_
model
(
v
llm_
model
)
,
TensorizerConfig
(
tensorizer_uri
=
model_path
))
TensorizerConfig
(
tensorizer_uri
=
model_path
))
with
vllm_runner
(
with
vllm_runner
(
...
@@ -180,7 +205,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -180,7 +205,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_
engine
,
serialize_vllm_model
(
get_torch_
model
(
v
llm_
model
)
,
TensorizerConfig
(
tensorizer_uri
=
model_path
))
TensorizerConfig
(
tensorizer_uri
=
model_path
))
model_loader_extra_config
=
{
model_loader_extra_config
=
{
...
@@ -191,29 +216,24 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -191,29 +216,24 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
openai_args
=
[
openai_args
=
[
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"--load-format"
,
"--model"
,
model_ref
,
"--dtype"
,
"float16"
,
"--load-format"
,
"tensorizer"
,
"--model-loader-extra-config"
,
"tensorizer"
,
"--model-loader-extra-config"
,
json
.
dumps
(
model_loader_extra_config
),
"--port"
,
"8000"
json
.
dumps
(
model_loader_extra_config
),
]
]
server
=
ServerRunner
.
remote
(
openai_args
)
with
RemoteOpenAIServer
(
openai_args
)
as
server
:
print
(
"Server ready."
)
assert
ray
.
get
(
server
.
ready
.
remote
())
print
(
"Server ready."
)
client
=
openai
.
OpenAI
(
client
=
server
.
get_client
()
base_url
=
"http://localhost:8000/v1"
,
completion
=
client
.
completions
.
create
(
model
=
model_ref
,
api_key
=
"token-abc123"
,
prompt
=
"Hello, my name is"
,
)
max_tokens
=
5
,
completion
=
client
.
completions
.
create
(
model
=
model_ref
,
temperature
=
0.0
)
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
id
is
not
None
assert
len
(
completion
.
choices
)
==
1
assert
len
(
completion
.
choices
)
==
1
assert
len
(
completion
.
choices
[
0
].
text
)
>=
5
assert
len
(
completion
.
choices
[
0
].
text
)
>=
5
assert
completion
.
choices
[
0
].
finish_reason
==
"length"
assert
completion
.
choices
[
0
].
finish_reason
==
"length"
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
5
,
prompt_tokens
=
6
,
total_tokens
=
11
)
completion_tokens
=
5
,
prompt_tokens
=
6
,
total_tokens
=
11
)
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
):
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
):
...
@@ -224,7 +244,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
...
@@ -224,7 +244,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
))
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
))
def
test_tensorizer_with_tp
(
vllm_runner
):
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Requires 2 GPUs"
)
def
test_tensorizer_with_tp_path_without_template
(
vllm_runner
):
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
model_ref
=
"EleutherAI/pythia-1.4b"
model_ref
=
"EleutherAI/pythia-1.4b"
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
...
@@ -238,8 +260,60 @@ def test_tensorizer_with_tp(vllm_runner):
...
@@ -238,8 +260,60 @@ def test_tensorizer_with_tp(vllm_runner):
s3_endpoint
=
"object.ord1.coreweave.com"
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
),
),
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
disable_custom_all_reduce
=
True
,
)
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Requires 2 GPUs"
)
def
test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs
(
vllm_runner
,
tmp_path
):
model_ref
=
"EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model
base_model
=
vllm_runner
(
model_ref
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
)
outputs
=
base_model
.
generate
(
prompts
,
sampling_params
)
base_model
.
model
.
llm_engine
.
model_executor
.
shutdown
()
del
base_model
cleanup
()
# load model with two shards and serialize with encryption
model_path
=
str
(
tmp_path
/
(
model_ref
+
"-%02d.tensors"
))
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
tensorizer_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
,
)
tensorize_vllm_model
(
engine_args
=
EngineArgs
(
model
=
model_ref
,
tensor_parallel_size
=
2
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
),
tensorizer_config
=
tensorizer_config
,
)
assert
os
.
path
.
isfile
(
model_path
%
0
),
"Serialization subprocess failed"
assert
os
.
path
.
isfile
(
model_path
%
1
),
"Serialization subprocess failed"
cleanup
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
tensor_parallel_size
=
2
,
load_format
=
"tensorizer"
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
model_loader_extra_config
=
tensorizer_config
)
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
assert
outputs
==
deserialized_outputs
def
test_vllm_tensorized_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
def
test_vllm_tensorized_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
model_ref
=
"facebook/opt-125m"
model_ref
=
"facebook/opt-125m"
...
@@ -248,7 +322,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
...
@@ -248,7 +322,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
with
vllm_runner
(
model_ref
)
as
vllm_model
:
with
vllm_runner
(
model_ref
)
as
vllm_model
:
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_
engine
,
config
)
serialize_vllm_model
(
get_torch_
model
(
v
llm_
model
)
,
config
)
assert
is_vllm_tensorized
(
config
)
assert
is_vllm_tensorized
(
config
)
...
...
tests/test_cache_block_hashing.py
View file @
705f6a35
...
@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
...
@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length
=
None
,
max_input_length
=
None
,
)
)
hashes
=
[]
hashes
:
List
[
List
[
List
[
int
]]]
=
[]
for
prefix
in
prefixes
:
for
prefix
in
prefixes
:
for
lora_int_id
in
concurrent_lora_int_ids
:
for
lora_int_id
in
concurrent_lora_int_ids
:
...
...
tests/test_embedded_commit.py
0 → 100644
View file @
705f6a35
import
vllm
def
test_embedded_commit_defined
():
assert
vllm
.
__commit__
!=
"COMMIT_HASH_PLACEHOLDER"
# 7 characters is the length of a short commit hash
assert
len
(
vllm
.
__commit__
)
>=
7
tests/test_logger.py
View file @
705f6a35
...
@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
...
@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
assert
not
logger
.
propagate
assert
not
logger
.
propagate
handler
=
logger
.
handlers
[
0
]
handler
=
logger
.
handlers
[
0
]
assert
isinstance
(
handler
,
logging
.
StreamHandler
)
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
level
==
logging
.
INFO
assert
handler
.
level
==
logging
.
INFO
...
...
tests/test_logits_processor.py
View file @
705f6a35
...
@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str):
...
@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str):
device
=
device
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
pin_memory
=
is_pin_memory_available
())
logits_processor_output
=
logits_processor
(
logits_processor_output
=
logits_processor
(
embedding
=
None
,
lm_head
=
None
,
hidden_states
=
input_tensor
,
hidden_states
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
...
...
tests/test_sharded_state_loader.py
View file @
705f6a35
...
@@ -39,7 +39,7 @@ def test_filter_subtensors():
...
@@ -39,7 +39,7 @@ def test_filter_subtensors():
filtered_state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
state_dict
)
filtered_state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
state_dict
)
assert
tuple
(
filtered_state_dict
.
keys
())
==
(
"a"
,
"b"
,
"c"
)
assert
tuple
(
filtered_state_dict
.
keys
())
==
(
"a"
,
"b"
,
"c"
)
for
key
,
tensor
in
filtered_state_dict
.
items
():
for
key
,
tensor
in
filtered_state_dict
.
items
():
# NOTE: don't use `e
u
qal` here, as the tensor might contain NaNs
# NOTE: don't use `eq
u
al` here, as the tensor might contain NaNs
assert
tensor
is
state_dict
[
key
]
assert
tensor
is
state_dict
[
key
]
...
...
tests/test_utils.py
View file @
705f6a35
...
@@ -7,7 +7,8 @@ from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
...
@@ -7,7 +7,8 @@ from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
import
pytest
import
pytest
from
vllm.utils
import
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
from
vllm.utils
import
(
FlexibleArgumentParser
,
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
)
from
.utils
import
error_on_warning
from
.utils
import
error_on_warning
...
@@ -130,3 +131,61 @@ def test_get_open_port():
...
@@ -130,3 +131,61 @@ def test_get_open_port():
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s3
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s3
:
s3
.
bind
((
"localhost"
,
get_open_port
()))
s3
.
bind
((
"localhost"
,
get_open_port
()))
os
.
environ
.
pop
(
"VLLM_PORT"
)
os
.
environ
.
pop
(
"VLLM_PORT"
)
# Tests for FlexibleArgumentParser
@
pytest
.
fixture
def
parser
():
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
'--image-input-type'
,
choices
=
[
'pixel_values'
,
'image_features'
])
parser
.
add_argument
(
'--model-name'
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
)
parser
.
add_argument
(
'--enable-feature'
,
action
=
'store_true'
)
return
parser
def
test_underscore_to_dash
(
parser
):
args
=
parser
.
parse_args
([
'--image_input_type'
,
'pixel_values'
])
assert
args
.
image_input_type
==
'pixel_values'
def
test_mixed_usage
(
parser
):
args
=
parser
.
parse_args
([
'--image_input_type'
,
'image_features'
,
'--model-name'
,
'facebook/opt-125m'
])
assert
args
.
image_input_type
==
'image_features'
assert
args
.
model_name
==
'facebook/opt-125m'
def
test_with_equals_sign
(
parser
):
args
=
parser
.
parse_args
(
[
'--image_input_type=pixel_values'
,
'--model-name=facebook/opt-125m'
])
assert
args
.
image_input_type
==
'pixel_values'
assert
args
.
model_name
==
'facebook/opt-125m'
def
test_with_int_value
(
parser
):
args
=
parser
.
parse_args
([
'--batch_size'
,
'32'
])
assert
args
.
batch_size
==
32
args
=
parser
.
parse_args
([
'--batch-size'
,
'32'
])
assert
args
.
batch_size
==
32
def
test_with_bool_flag
(
parser
):
args
=
parser
.
parse_args
([
'--enable_feature'
])
assert
args
.
enable_feature
is
True
args
=
parser
.
parse_args
([
'--enable-feature'
])
assert
args
.
enable_feature
is
True
def
test_invalid_choice
(
parser
):
with
pytest
.
raises
(
SystemExit
):
parser
.
parse_args
([
'--image_input_type'
,
'invalid_choice'
])
def
test_missing_required_argument
(
parser
):
parser
.
add_argument
(
'--required-arg'
,
required
=
True
)
with
pytest
.
raises
(
SystemExit
):
parser
.
parse_args
([])
tests/tokenization/test_detokenize.py
View file @
705f6a35
from
typing
import
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
...
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
}
for
token_id
in
complete_sequence_token_ids
]
}
for
token_id
in
complete_sequence_token_ids
]
def
create_dummy_prompt_logprobs
(
complete_sequence_token_ids
:
List
[
int
]
)
->
List
[
Optional
[
Dict
[
int
,
Any
]]]:
# logprob for the first prompt token is None.
logprobs
:
List
[
Optional
[
Dict
[
int
,
Any
]]]
=
[
None
]
logprobs
.
extend
(
create_dummy_logprobs
(
complete_sequence_token_ids
)[
1
:])
return
logprobs
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
,
False
])
...
@@ -153,8 +162,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -153,8 +162,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially.
# Run sequentially.
seq
=
create_sequence
()
seq
=
create_sequence
()
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
sequential_logprobs_text_chosen_token
=
[]
sequential_logprobs_text_chosen_token
:
List
[
str
]
=
[]
sequential_logprobs_text_other_token
=
[]
sequential_logprobs_text_other_token
:
List
[
str
]
=
[]
for
new_token
,
logprobs
in
zip
(
complete_sequence_token_ids
,
for
new_token
,
logprobs
in
zip
(
complete_sequence_token_ids
,
dummy_logprobs
):
dummy_logprobs
):
seq
.
append_token_id
(
new_token
,
logprobs
)
seq
.
append_token_id
(
new_token
,
logprobs
)
...
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
])
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
List
[
int
],
def
test_decode_prompt_logprobs
(
complete_sequence
:
str
,
detokenizer
:
Detokenizer
):
complete_sequence_token_ids
:
List
[
int
],
detokenizer
:
Detokenizer
,
skip_special_tokens
:
bool
):
"""Verify Detokenizer decodes prompt logprobs correctly."""
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params
=
SamplingParams
(
skip_special_tokens
=
skip_special_tokens
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
prompt_logprobs
=
1
)
prompt_logprobs
=
1
)
# Run sequentially.
# Run sequentially.
...
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
...
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
seqs
=
[
seq
],
seqs
=
[
seq
],
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
arrival_time
=
0.0
)
arrival_time
=
0.0
)
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
dummy_logprobs
=
create_dummy_prompt_logprobs
(
complete_sequence_token_ids
)
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
dummy_logprobs
)
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
decoded_prompt_logprobs
=
dummy_logprobs
dummy_logprobs
,
position_offset
=
0
)
# First logprob is None.
decoded_prompt_logprobs
:
List
[
Dict
[
int
,
Any
]]
=
dummy_logprobs
[
1
:]
# type: ignore
if
skip_special_tokens
:
# decoded_prompt_logprobs doesn't contain the first token.
# Text for logprobs for the chosen token should be the same as the
token_ids
=
complete_sequence_token_ids
# prompt text. Note that this will only be true if we skip
tokenzier
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
# special tokens.
text_full
=
tokenzier
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
assert
complete_sequence
==
""
.
join
([
text_first
=
tokenzier
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
True
)
logprobs
[
token_id
].
decoded_token
for
token_id
,
logprobs
in
zip
(
text
=
text_full
[
len
(
text_first
):]
complete_sequence_token_ids
,
decoded_prompt_logprobs
)
])
# Text for logprobs for the chosen token should be the same as the
assert
complete_sequence
!=
""
.
join
([
# prompt text. Note that the first logprob is None.
logprobs
[
token_id
+
1
].
decoded_token
for
token_id
,
logprobs
in
zip
(
assert
text
==
""
.
join
([
complete_sequence_token_ids
,
decoded_prompt_logprobs
)
logprobs
[
token_id
].
decoded_token
])
for
token_id
,
logprobs
in
zip
(
token_ids
[
1
:],
decoded_prompt_logprobs
)
])
assert
text
!=
""
.
join
([
logprobs
[
token_id
+
1
].
decoded_token
for
token_id
,
logprobs
in
zip
(
token_ids
[
1
:],
decoded_prompt_logprobs
)
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
7
,
16
,
-
1
])
def
test_decode_prompt_logprobs_chunked_prefill
(
vllm_runner
,
model
,
chunked_prefill_token_size
:
int
,
example_prompts
,
):
max_num_seqs
=
256
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
if
chunked_prefill_token_size
!=
-
1
:
enable_chunked_prefill
=
True
max_num_seqs
=
min
(
chunked_prefill_token_size
,
max_num_seqs
)
max_num_batched_tokens
=
chunked_prefill_token_size
with
vllm_runner
(
model
,
dtype
=
"half"
,
max_logprobs
=
5
,
gpu_memory_utilization
=
0.5
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_seqs
)
as
vllm_model
:
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
10
,
logprobs
=
5
,
prompt_logprobs
=
5
,
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
for
idx
,
result
in
enumerate
(
vllm_results
):
assert
result
.
prompt_logprobs
is
not
None
assert
result
.
prompt_logprobs
[
0
]
is
None
# Compared detokenized prompts ids to original prompt.
generated_string
=
""
for
(
prompt_token
,
prompt_logprobs
)
in
zip
(
result
.
prompt_token_ids
[
1
:],
result
.
prompt_logprobs
[
1
:]):
# prompt_logprobs is a dict of the token_id: logprob
# We select the token_id corresponding to the actual prompt
# Decoded token in the detokenized string corresponding to this
# prompt token.
generated_string
+=
prompt_logprobs
[
prompt_token
].
decoded_token
assert
generated_string
==
example_prompts
[
idx
],
(
"Detokenized prompt logprobs do not match original prompt"
)
tests/tokenization/test_get_eos.py
0 → 100644
View file @
705f6a35
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
def
test_get_llama3_eos_token
():
model_name
=
"meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer
=
get_tokenizer
(
model_name
)
assert
tokenizer
.
eos_token_id
==
128009
generation_config
=
try_get_generation_config
(
model_name
,
trust_remote_code
=
False
)
assert
generation_config
is
not
None
assert
generation_config
.
eos_token_id
==
[
128001
,
128009
]
def
test_get_blip2_eos_token
():
model_name
=
"Salesforce/blip2-opt-2.7b"
tokenizer
=
get_tokenizer
(
model_name
)
assert
tokenizer
.
eos_token_id
==
2
generation_config
=
try_get_generation_config
(
model_name
,
trust_remote_code
=
False
)
assert
generation_config
is
not
None
assert
generation_config
.
eos_token_id
==
50118
tests/tokenization/test_image_processor.py
deleted
100644 → 0
View file @
af837396
import
pytest
from
transformers.image_processing_utils
import
BaseImageProcessor
from
vllm.transformers_utils.image_processor
import
get_image_processor
IMAGE_PROCESSOR_NAMES
=
[
"llava-hf/llava-1.5-7b-hf"
,
"llava-hf/llava-v1.6-34b-hf"
,
]
@
pytest
.
mark
.
parametrize
(
"processor_name"
,
IMAGE_PROCESSOR_NAMES
)
def
test_image_processor_revision
(
processor_name
:
str
):
# Assume that "main" branch always exists
image_processor
=
get_image_processor
(
processor_name
,
revision
=
"main"
)
assert
isinstance
(
image_processor
,
BaseImageProcessor
)
# Assume that "never" branch always does not exist
with
pytest
.
raises
(
OSError
,
match
=
'not a valid git identifier'
):
get_image_processor
(
processor_name
,
revision
=
"never"
)
tests/tokenization/test_tokenizer_group.py
View file @
705f6a35
import
asyncio
import
asyncio
import
os
import
os
import
sys
from
typing
import
List
,
Optional
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
...
@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
max_num_seqs
=
1
,
max_num_seqs
=
1
,
max_input_length
=
None
)
max_input_length
=
None
)
tokenizer_pool
.
ping
()
tokenizer_pool
.
ping
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
"ray"
])
async
def
test_tokenizer_group_ray_pool_fault_tolerance
(
tokenizer_group_type
):
"""Test that Ray tokenizer pool group can recover from failures and
if that's not possible, mark itself as unhealthy."""
class
FailingTokenizerGroup
(
TokenizerGroup
):
def
__init__
(
self
,
*
args
,
fail_at
:
Optional
[
List
[
int
]]
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
i
=
0
self
.
fail_at
=
fail_at
or
[]
def
encode
(
self
,
*
args
,
**
kwargs
):
self
.
i
+=
1
if
self
.
i
in
self
.
fail_at
:
sys
.
exit
(
1
)
return
super
().
encode
(
*
args
,
**
kwargs
)
class
FailingRayTokenizerGroupPool
(
RayTokenizerGroupPool
):
_worker_cls
=
FailingTokenizerGroup
# Fail at first iteration
fail_at
=
[
1
]
tokenizer_pool_config
=
get_tokenizer_pool_config
(
tokenizer_group_type
)
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
fail_at
=
fail_at
)
tokenizer_actors
=
tokenizer_group_pool
.
tokenizer_actors
.
copy
()
# Modify fail at to not fail at all (will be re-read when actor is
# re-initialized).
fail_at
[
0
]
=
1000
# We should recover successfully.
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# Check that we have a new actor
assert
len
(
tokenizer_group_pool
.
tokenizer_actors
)
==
len
(
tokenizer_actors
)
assert
tokenizer_group_pool
.
tokenizer_actors
!=
tokenizer_actors
# Fail at first iteration
fail_at
=
[
1
]
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
fail_at
=
fail_at
)
# We should fail after re-initialization.
with
pytest
.
raises
(
RuntimeError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# check_health should raise the same thing
with
pytest
.
raises
(
RuntimeError
):
tokenizer_group_pool
.
check_health
()
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
# cause a re-initialization.
fail_at
=
[]
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
2
,
fail_at
=
fail_at
)
tokenizer_actors
=
tokenizer_group_pool
.
tokenizer_actors
.
copy
()
# Prompt too long error
with
pytest
.
raises
(
ValueError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
*
100
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# Actors should stay the same.
assert
tokenizer_group_pool
.
tokenizer_actors
==
tokenizer_actors
tests/tracing/__init__.py
0 → 100644
View file @
705f6a35
tests/tracing/test_tracing.py
0 → 100644
View file @
705f6a35
import
os
import
threading
from
concurrent
import
futures
from
typing
import
Callable
,
Dict
,
Iterable
,
Literal
import
grpc
import
pytest
from
opentelemetry.proto.collector.trace.v1.trace_service_pb2
import
(
ExportTraceServiceResponse
)
from
opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc
import
(
TraceServiceServicer
,
add_TraceServiceServicer_to_server
)
from
opentelemetry.proto.common.v1.common_pb2
import
AnyValue
,
KeyValue
from
opentelemetry.sdk.environment_variables
import
(
OTEL_EXPORTER_OTLP_TRACES_INSECURE
)
from
vllm
import
LLM
,
SamplingParams
from
vllm.tracing
import
SpanAttributes
FAKE_TRACE_SERVER_ADDRESS
=
"localhost:4317"
FieldName
=
Literal
[
'bool_value'
,
'string_value'
,
'int_value'
,
'double_value'
,
'array_value'
]
def
decode_value
(
value
:
AnyValue
):
field_decoders
:
Dict
[
FieldName
,
Callable
]
=
{
"bool_value"
:
(
lambda
v
:
v
.
bool_value
),
"string_value"
:
(
lambda
v
:
v
.
string_value
),
"int_value"
:
(
lambda
v
:
v
.
int_value
),
"double_value"
:
(
lambda
v
:
v
.
double_value
),
"array_value"
:
(
lambda
v
:
[
decode_value
(
item
)
for
item
in
v
.
array_value
.
values
]),
}
for
field
,
decoder
in
field_decoders
.
items
():
if
value
.
HasField
(
field
):
return
decoder
(
value
)
raise
ValueError
(
f
"Couldn't decode value:
{
value
}
"
)
def
decode_attributes
(
attributes
:
Iterable
[
KeyValue
]):
return
{
kv
.
key
:
decode_value
(
kv
.
value
)
for
kv
in
attributes
}
class
FakeTraceService
(
TraceServiceServicer
):
def
__init__
(
self
):
self
.
request
=
None
self
.
evt
=
threading
.
Event
()
def
Export
(
self
,
request
,
context
):
self
.
request
=
request
self
.
evt
.
set
()
return
ExportTraceServiceResponse
()
@pytest.fixture
def trace_service():
    """Run a fake gRPC trace service for the duration of one test.

    Starts a single-worker gRPC server bound (insecurely) to
    ``FAKE_TRACE_SERVER_ADDRESS`` with a ``FakeTraceService`` registered on
    it, yields that service to the test, and stops the server afterwards.
    """
    fake_service = FakeTraceService()
    worker_pool = futures.ThreadPoolExecutor(max_workers=1)
    grpc_server = grpc.server(worker_pool)
    add_TraceServiceServicer_to_server(fake_service, grpc_server)
    grpc_server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    grpc_server.start()

    yield fake_service

    # ``None`` is the grace period passed to ``stop``.
    grpc_server.stop(None)
def test_traces(trace_service):
    """Generate one completion and verify the span exported for it.

    An ``LLM`` is pointed at the fake trace service, a single prompt is
    generated, and the attributes of the first exported span are compared
    with the sampling parameters, token counts and latency metrics of the
    request output.

    Raises:
        TimeoutError: if the fake service receives no export within the
            timeout.
    """
    # NOTE(review): presumably read by the OTLP exporter when the engine
    # sets up tracing, so it has to be in place before ``LLM`` is built.
    # The previous value is restored afterwards so this process-global
    # setting does not leak into other tests.
    previous_insecure = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_INSECURE)
    os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"
    try:
        sampling_params = SamplingParams(temperature=0.01,
                                         top_p=0.1,
                                         max_tokens=256)
        model = "facebook/opt-125m"
        llm = LLM(
            model=model,
            otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
        )
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        # Block until the fake service has recorded an export request.
        timeout = 5
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout")
    finally:
        if previous_insecure is None:
            os.environ.pop(OTEL_EXPORTER_OTLP_TRACES_INSECURE, None)
        else:
            os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = previous_insecure

    # Only the first span of the first scope/resource is inspected.
    attributes = decode_attributes(trace_service.request.resource_spans[0].
                                   scope_spans[0].spans[0].attributes)
    assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
    assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
    assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
        outputs[0].prompt_token_ids)
    completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
    assert attributes.get(
        SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens
    metrics = outputs[0].metrics
    assert attributes.get(
        SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
    ttft = metrics.first_token_time - metrics.arrival_time
    assert attributes.get(
        SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
    e2e_time = metrics.finished_time - metrics.arrival_time
    assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
tests/utils.py
View file @
705f6a35
...
@@ -4,57 +4,120 @@ import sys
...
@@ -4,57 +4,120 @@ import sys
import
time
import
time
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
import
openai
import
ray
import
ray
import
requests
import
requests
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.utils
import
get_open_port
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
,
is_hip
if is_hip():
    from amdsmi import (amdsmi_get_gpu_vram_usage,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down)

    @contextmanager
    def _nvml():
        """Initialise amdsmi for the managed scope and shut it down after."""
        # Initialise before entering ``try``: if init itself fails there is
        # nothing to shut down, and calling the shutdown routine anyway
        # could raise a second error that hides the real cause.
        amdsmi_init()
        try:
            yield
        finally:
            amdsmi_shut_down()
else:
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit, nvmlShutdown)

    @contextmanager
    def _nvml():
        """Initialise NVML for the managed scope and shut it down after."""
        # Initialise before entering ``try``: if init itself fails there is
        # nothing to shut down, and calling the shutdown routine anyway
        # could raise a second error that hides the real cause.
        nvmlInit()
        try:
            yield
        finally:
            nvmlShutdown()
VLLM_PATH
=
Path
(
__file__
).
parent
.
parent
"""Path to root of the vLLM repository."""
class
RemoteOpenAIServer
:
DUMMY_API_KEY
=
"token-abc123"
# vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S
=
600
# wait for server to start for 600 seconds
# Path to root of repository so that utilities can be imported by ray workers
def
__init__
(
self
,
cli_args
:
List
[
str
],
*
,
auto_port
:
bool
=
True
)
->
None
:
VLLM_PATH
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
os
.
pardir
,
os
.
pardir
))
if
auto_port
:
if
"-p"
in
cli_args
or
"--port"
in
cli_args
:
raise
ValueError
(
"You have manually specified the port"
"when `auto_port=True`."
)
cli_args
=
cli_args
+
[
"--port"
,
str
(
get_open_port
())]
@
ray
.
remote
(
num_gpus
=
1
)
parser
=
FlexibleArgumentParser
(
class
ServerRunner
:
description
=
"vLLM's remote OpenAI server."
)
MAX_SERVER_START_WAIT_S
=
600
# wait for server to start for 60 seconds
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
(
cli_args
)
self
.
host
=
str
(
args
.
host
or
'localhost'
)
self
.
port
=
int
(
args
.
port
)
def
__init__
(
self
,
args
):
env
=
os
.
environ
.
copy
()
env
=
os
.
environ
.
copy
()
env
[
"PYTHONUNBUFFERED"
]
=
"1"
# the current process might initialize cuda,
# to be safe, we should use spawn method
env
[
'VLLM_WORKER_MULTIPROC_METHOD'
]
=
'spawn'
self
.
proc
=
subprocess
.
Popen
(
self
.
proc
=
subprocess
.
Popen
(
[
sys
.
executable
,
"-m"
,
"vllm.entrypoints.openai.api_server"
]
+
[
sys
.
executable
,
"-m"
,
"vllm.entrypoints.openai.api_server"
]
+
args
,
cli_
args
,
env
=
env
,
env
=
env
,
stdout
=
sys
.
stdout
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stderr
,
stderr
=
sys
.
stderr
)
)
self
.
_wait_for_server
(
url
=
self
.
url_for
(
"health"
),
self
.
_wait_for_server
(
)
timeout
=
self
.
MAX_SERVER_START_WAIT_S
)
def
ready
(
self
):
def
__enter__
(
self
):
return
True
return
self
def
_wait_for_server
(
self
):
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
proc
.
terminate
()
def
_wait_for_server
(
self
,
*
,
url
:
str
,
timeout
:
float
):
# run health check
# run health check
start
=
time
.
time
()
start
=
time
.
time
()
while
True
:
while
True
:
try
:
try
:
if
requests
.
get
(
if
requests
.
get
(
url
).
status_code
==
200
:
"http://localhost:8000/health"
).
status_code
==
200
:
break
break
except
Exception
as
err
:
except
Exception
as
err
:
if
self
.
proc
.
poll
()
is
not
None
:
result
=
self
.
proc
.
poll
()
if
result
is
not
None
and
result
!=
0
:
raise
RuntimeError
(
"Server exited unexpectedly."
)
from
err
raise
RuntimeError
(
"Server exited unexpectedly."
)
from
err
time
.
sleep
(
0.5
)
time
.
sleep
(
0.5
)
if
time
.
time
()
-
start
>
self
.
MAX_SERVER_START_WAIT_S
:
if
time
.
time
()
-
start
>
timeout
:
raise
RuntimeError
(
raise
RuntimeError
(
"Server failed to start in time."
)
from
err
"Server failed to start in time."
)
from
err
def
__del__
(
self
):
@
property
if
hasattr
(
self
,
"proc"
):
def
url_root
(
self
)
->
str
:
self
.
proc
.
terminate
()
return
f
"http://
{
self
.
host
}
:
{
self
.
port
}
"
def
url_for
(
self
,
*
parts
:
str
)
->
str
:
return
self
.
url_root
+
"/"
+
"/"
.
join
(
parts
)
def
get_client
(
self
):
return
openai
.
OpenAI
(
base_url
=
self
.
url_for
(
"v1"
),
api_key
=
self
.
DUMMY_API_KEY
,
)
def
get_async_client
(
self
):
return
openai
.
AsyncOpenAI
(
base_url
=
self
.
url_for
(
"v1"
),
api_key
=
self
.
DUMMY_API_KEY
,
)
def
init_test_distributed_environment
(
def
init_test_distributed_environment
(
...
@@ -73,13 +136,15 @@ def init_test_distributed_environment(
...
@@ -73,13 +136,15 @@ def init_test_distributed_environment(
ensure_model_parallel_initialized
(
tp_size
,
pp_size
)
ensure_model_parallel_initialized
(
tp_size
,
pp_size
)
def
multi_process_
tensor_
parallel
(
def
multi_process_parallel
(
tp_size
:
int
,
tp_size
:
int
,
pp_size
:
int
,
pp_size
:
int
,
test_target
,
test_target
:
Any
,
)
->
None
:
)
->
None
:
# Using ray helps debugging the error when it failed
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray
.
init
(
runtime_env
=
{
"working_dir"
:
VLLM_PATH
})
ray
.
init
(
runtime_env
=
{
"working_dir"
:
VLLM_PATH
})
distributed_init_port
=
get_open_port
()
distributed_init_port
=
get_open_port
()
...
@@ -102,3 +167,43 @@ def error_on_warning():
...
@@ -102,3 +167,43 @@ def error_on_warning():
warnings
.
simplefilter
(
"error"
)
warnings
.
simplefilter
(
"error"
)
yield
yield
@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    """Poll GPU memory usage until every device is at or below a threshold.

    Usage is read through the vendor management library (amdsmi on HIP,
    NVML otherwise) rather than through pytorch, to reduce measurement
    error from the torch cuda context. One status line is printed per poll
    and polls are five seconds apart.

    Args:
        devices: indices of the devices to watch.
        threshold_bytes: maximum allowed usage per device, in bytes.
        timeout_s: how long to keep polling before giving up.

    Raises:
        ValueError: if some device is still above the threshold once
            ``timeout_s`` seconds have elapsed.
    """
    threshold_gb = threshold_bytes / 2**30
    start_time = time.time()
    while True:
        output_raw: Dict[int, float] = {}
        for device in devices:
            if is_hip():
                handle = amdsmi_get_processor_handles()[device]
                usage = amdsmi_get_gpu_vram_usage(handle)
                # NOTE(review): presumably "vram_used" is reported in MiB,
                # hence the 2**10 divisor to get GiB -- confirm.
                output_raw[device] = usage["vram_used"] / 2**10
            else:
                handle = nvmlDeviceGetHandleByIndex(device)
                output_raw[device] = (nvmlDeviceGetMemoryInfo(handle).used /
                                      2**30)

        readings = ''.join(f'{dev}={used:.02f}; '
                           for dev, used in output_raw.items())
        print('gpu memory used (GB): ' + readings)

        dur_s = time.time() - start_time
        if all(used <= threshold_gb for used in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            return

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
Prev
1
…
11
12
13
14
15
16
17
18
19
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment