norm / vllm · Commits

Commit 18bfcdd0 (unverified)
[Speculative decoding 2/9] Multi-step worker for draft model (#2424)

Authored Jan 21, 2024 by Cade Daniel; committed by GitHub on Jan 21, 2024.
Parent commit: 71d63ed7
Showing 11 changed files with 658 additions and 12 deletions (+658, -12):
    tests/worker/__init__.py                               +0    -0
    tests/worker/spec_decode/__init__.py                   +0    -0
    tests/worker/spec_decode/test_multi_step_worker.py     +261  -0
    tests/worker/spec_decode/utils.py                      +177  -0
    vllm/engine/llm_engine.py                              +5    -3
    vllm/engine/ray_utils.py                               +3    -4
    vllm/model_executor/parallel_utils/parallel_state.py   +25   -0
    vllm/utils.py                                          +4    -0
    vllm/worker/model_runner.py                            +2    -2
    vllm/worker/spec_decode/multi_step_worker.py           +178  -0
    vllm/worker/worker.py                                  +3    -3
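For orientation, here is a minimal sketch (not part of the commit) of how the multi-step draft-worker API introduced here is exercised, mirroring the tests added below. It assumes a CUDA-capable machine with this revision of vLLM installed and is run from the repository root so the test helpers are importable; the prompt, step count, and block sizes are illustrative values taken from the tests.

# Sketch only: drive the new MultiStepWorker for several draft steps in one call.
from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
from tests.worker.spec_decode.utils import (
    create_worker, create_execute_model_data,
    create_seq_group_metadata_from_prompts)

block_size = 16
num_gpu_blocks = 2048 // block_size
worker = create_worker(MultiStepWorker, 'JackFram/llama-68m', block_size,
                       num_gpu_blocks, seed=100)

prompts = [[1, 2, 3, 4, 5]]
num_steps = 4  # draft k tokens per scheduler invocation
data = create_execute_model_data(
    create_seq_group_metadata_from_prompts(
        prompts, num_gpu_blocks, block_size,
        final_seq_lens=[len(p) + num_steps for p in prompts]))

# One call runs num_steps forward passes and returns one SamplerOutput per step.
outputs = worker.execute_model_multi_step(**data.to_dict(), num_steps=num_steps)
assert len(outputs) == num_steps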
tests/worker/__init__.py (new file, mode 100644; empty)

tests/worker/spec_decode/__init__.py (new file, mode 100644; empty)
tests/worker/spec_decode/test_multi_step_worker.py (new file, mode 100644)
import torch
import random
import pytest
from unittest.mock import MagicMock

from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed

from .utils import (create_execute_model_data, create_worker,
                    create_seq_group_metadata_from_prompts, zero_kv_cache,
                    patch_execute_model_with_seeds,
                    assert_logprobs_dict_allclose)


@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
    """Test that the multi step worker checks for sufficient space in the KV
    cache. It should throw if it cannot run all the steps.
    """
    block_size = 16
    num_gpu_blocks = 2048 // block_size

    prompts = [
        list(range(block_size * 3)),
        list(range(block_size * 2)),
    ]

    prev_output_tokens = [
        list(range(block_size * 1)),
        list(range(block_size * 2)),
    ]

    final_seq_lens = [
        len(prompt + output) + num_steps
        for prompt, output in zip(prompts, prev_output_tokens)
    ]

    inputs = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_seq_lens,
        continuations=prev_output_tokens)

    assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
    worker = MagicMock()
    worker.model_runner.block_size = block_size

    for seq_group_metadata in inputs:
        original_block_tables = seq_group_metadata.block_tables

        # No exception.
        assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = {
            seq_id: []
            for seq_id, physical_blocks in original_block_tables.items()
        }

        # Expect exception.
        with pytest.raises(ValueError,
                           match='times but found insufficient KV space for'):
            assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = original_block_tables


@torch.inference_mode()
def test_same_output_for_single_step():
    """Verify the multi step worker produces the same output as the normal
    worker for num_steps=1.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    multi_step_worker.model_runner = worker.model_runner
    multi_step_worker.cache_engine = worker.cache_engine

    num_steps = 1

    prompts = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]

    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]

    multi_step_execute_model_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    single_step_execute_model_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    actual_output = multi_step_worker.execute_model_multi_step(
        **multi_step_execute_model_data.to_dict(), num_steps=num_steps)
    assert len(actual_output) == num_steps
    actual_output = actual_output[0]

    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        **single_step_execute_model_data.to_dict(), )

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
    ]
    actual_logprobs = [output.samples[0].logprobs for output in actual_output]

    expected_token_ids = [
        output.samples[0].output_token for output in expected_output
    ]
    expected_logprobs = [
        output.samples[0].logprobs for output in expected_output
    ]

    assert actual_token_ids == expected_token_ids

    print(f'{actual_logprobs=}')
    print(f'{expected_logprobs=}')
    assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)


@torch.inference_mode()
def test_same_output_for_multi_step():
    """Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once, and
    then runs the worker num_steps times, and compares the output.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    # Make sure we go over the block boundary.
    num_steps = block_size + 1

    random.seed(seed)
    prompts = [[
        random.randint(0, 1000) for _ in range(random.randint(10, 20))
    ] for _ in range(10)]

    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]

    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    continuations = [[1] for _ in prompts]
    execute_model_data = create_execute_model_data(
        create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_seq_lens=final_seq_lens), )

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    multi_step_output = multi_step_worker.execute_model_multi_step(
        **execute_model_data.to_dict(), num_steps=num_steps)

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    single_step_output = []
    continuations = [[1] for _ in prompts]
    set_random_seed(seed)

    for _ in multi_step_output:

        execute_model_data = create_execute_model_data(
            create_seq_group_metadata_from_prompts(
                prompts,
                num_gpu_blocks,
                block_size,
                continuations=continuations,
                final_seq_lens=final_seq_lens))

        single_step_output.append(
            worker.execute_model(**execute_model_data.to_dict(), ))

        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Get token ids and logprobs for comparison.
    multi_step_output_logprobs = [[] for _ in prompts]
    single_step_output_logprobs = [[] for _ in prompts]

    multi_step_output_token_ids = [[] for _ in prompts]
    single_step_output_token_ids = [[] for _ in prompts]
    for i, _ in enumerate(prompts):
        for multi_step, single_step in zip(multi_step_output,
                                           single_step_output):
            multi_step_output_token_ids[i].append(
                multi_step[i].samples[0].output_token)
            single_step_output_token_ids[i].append(
                single_step[i].samples[0].output_token)

            multi_step_output_logprobs[i].append(
                multi_step[i].samples[0].logprobs)
            single_step_output_logprobs[i].append(
                single_step[i].samples[0].logprobs)

    # Print per-sequence token ids
    for i, (multi_step_tokens, single_step_tokens) in enumerate(
            zip(multi_step_output_token_ids, single_step_output_token_ids)):
        print(f'{i=} {multi_step_tokens=}')
        print(f'{i=} {single_step_tokens=}')
        print(f'{i=} equal {multi_step_tokens == single_step_tokens}')

    # Assert token ids are equal.
    for multi_step_tokens, single_step_tokens in zip(
            multi_step_output_token_ids, single_step_output_token_ids):
        assert multi_step_tokens == single_step_tokens

    # Assert logprobs are equal.
    for multi_step_logprobs, single_step_logprobs in zip(
            multi_step_output_logprobs, single_step_output_logprobs):
        assert_logprobs_dict_allclose(multi_step_logprobs,
                                      single_step_logprobs)
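The KV-space check exercised by test_assert_enough_kv_space above reduces to simple block arithmetic. A standalone sketch of the same bound (plain Python, no vLLM imports; the concrete lengths below are illustrative, not taken from the test):

# Standalone sketch of the bound enforced by MultiStepWorker._assert_enough_kv_space.
block_size = 16
prompt_len, output_len, num_steps = 48, 16, 8

final_seq_len = prompt_len + output_len + num_steps      # 72 tokens after drafting
required_num_kv_slots = final_seq_len - 1                # 71: KV for a token is written
                                                         # one iteration after sampling it
num_physical_blocks = 5                                  # blocks the scheduler allocated
allocated_kv_slots = num_physical_blocks * block_size    # 80 slots

assert required_num_kv_slots <= allocated_kv_slots       # enough room for 8 draft steps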
tests/worker/spec_decode/utils.py (new file, mode 100644)
import torch

from typing import List, Optional, Dict

from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
from dataclasses import dataclass, fields


@dataclass
class ExecuteModelData:
    """Helper data structure which facilitates cleaner tests.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata]
    blocks_to_swap_in: Dict[int, int]
    blocks_to_swap_out: Dict[int, int]
    blocks_to_copy: Dict[int, List[int]]

    def to_dict(self):
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))


def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size


def create_execute_model_data(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]] = None,
    blocks_to_swap_out: Optional[Dict[int, int]] = None,
    blocks_to_copy: Optional[Dict[int, int]] = None,
) -> ExecuteModelData:
    if blocks_to_swap_in is None:
        blocks_to_swap_in = {}
    if blocks_to_swap_out is None:
        blocks_to_swap_out = {}
    if blocks_to_copy is None:
        blocks_to_copy = {}

    return ExecuteModelData(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )


def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
    seed_iter = iter(rand_seeds)
    original_execute_model = worker.execute_model

    def new_execute_model(*args, **kwargs):
        result = original_execute_model(*args, **kwargs)
        set_random_seed(next(seed_iter))
        return result

    return new_execute_model


def zero_kv_cache(cache_engine: CacheEngine):
    assert cache_engine.gpu_cache
    for key_blocks, value_blocks in cache_engine.gpu_cache:
        key_blocks.zero_()
        value_blocks.zero_()


def create_worker(cls: type,
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True):
    engine_args = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
    )

    (model_config, cache_config, parallel_config,
     scheduler_config) = engine_args.create_engine_configs()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    worker = cls(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
    )

    worker.init_model()
    worker.load_model()

    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0
    worker.init_cache_engine(cache_config)
    worker.warm_up_model()

    return worker


def create_seq_group_metadata_from_prompts(
    prompts: List[List[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_seq_lens: List[int],
    continuations: Optional[List[List[int]]] = None,
    num_tokens_processed: Optional[List[int]] = None,
    seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:

    if continuations is None:
        continuations = [[] for _ in prompts]

    if num_tokens_processed is None:
        # Default to 1 token missing from kv cache for generation sequences.
        num_tokens_processed = []
        for continuation, prompt in zip(continuations, prompts):
            # If prefill, then default to zero tokens processed.
            if not continuation:
                num_tokens_processed.append(0)
            else:
                # If generation, then default to all but one tokens processed.
                num_tokens_processed.append(
                    len(continuation) + len(prompt) - 1)

    if seq_ids is None:
        seq_ids = list(i for i, _ in enumerate(prompts))

    free_gpu_blocks = list(range(num_gpu_blocks))

    block_allocations = {
        i: [
            free_gpu_blocks.pop()
            for _ in range(round_up_to_next_block(final_len, block_size))
        ]
        for i, final_len in enumerate(final_seq_lens)
    }

    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=len(cont_token_ids) == 0,
            seq_data={
                i: SequenceData(prompt_token_ids=prompt_token_ids[:] +
                                cont_token_ids[:])
            },
            sampling_params=SamplingParams(temperature=0.0, ),
            block_tables={i: block_allocations[i][:]},
        ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
        enumerate(zip(prompts, continuations, num_tokens_processed))
    ]


def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, float]],
        expected_logprobs: List[Dict[int, float]]) -> None:
    for single_step_actual_logprobs, single_step_expected_logprobs in zip(
            actual_logprobs, expected_logprobs):
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys())
        for token_id in single_step_actual_logprobs:
            actual = torch.tensor(single_step_actual_logprobs[token_id])
            expected = torch.tensor(single_step_expected_logprobs[token_id])
            assert torch.allclose(actual, expected)
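The block-allocation logic above hinges on round_up_to_next_block. A quick check of its behavior (the function is restated here so the snippet runs standalone):

# Quick check of round_up_to_next_block as defined in the test helpers above.
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size

assert round_up_to_next_block(16, 16) == 1   # exactly one full block
assert round_up_to_next_block(17, 16) == 2   # one token over needs a second block
assert round_up_to_next_block(1, 16) == 1    # partial blocks still occupy a block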
vllm/engine/llm_engine.py

@@ -18,7 +18,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
-from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
+from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
 
 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@@ -132,7 +132,8 @@ class LLMEngine:
                 "Ray is required if parallel_config.world_size > 1.")
 
         self.workers: List[Worker] = []
-        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
         self.driver_worker = Worker(
             self.model_config,
             self.parallel_config,

@@ -207,7 +208,8 @@ class LLMEngine:
         for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
             worker.set_cuda_visible_devices.remote(node_gpus[node_id])
 
-        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port)
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
vllm/engine/ray_utils.py

@@ -65,10 +65,9 @@ def initialize_cluster(
             the default Ray cluster address.
 
     Returns:
-        A tuple of (`distributed_init_method`, `placement_group`). The
-        `distributed_init_method` is the address for initializing the
-        distributed backend. `placement_group` includes the specification
-        of the resources for each distributed worker.
+        An optional `PlacementGroup`. It includes the specification
+        of the resources for each distributed worker. None if Ray is
+        not used.
     """
     if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
vllm/model_executor/parallel_utils/parallel_state.py

@@ -83,6 +83,31 @@ def initialize_model_parallel(
     _PIPELINE_GLOBAL_RANKS = ranks
 
 
+def ensure_model_parallel_initialized(
+    tensor_model_parallel_size: int,
+    pipeline_model_parallel_size: int,
+) -> None:
+    """Helper to initialize model parallel groups if they are not initialized,
+    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
+    values if the model parallel groups are initialized.
+    """
+    if not model_parallel_is_initialized():
+        initialize_model_parallel(tensor_model_parallel_size,
+                                  pipeline_model_parallel_size)
+        return
+
+    assert (get_tensor_model_parallel_world_size() ==
+            tensor_model_parallel_size), (
+        "tensor parallel group already initialized, but of unexpected size: "
+        f"{get_tensor_model_parallel_world_size()=} vs. "
+        f"{tensor_model_parallel_size=}")
+    assert (get_pipeline_model_parallel_world_size(
+    ) == pipeline_model_parallel_size), (
+        "pipeline parallel group already initialized, but of unexpected size: "
+        f"{get_pipeline_model_parallel_world_size()=} vs. "
+        f"{pipeline_model_parallel_size=}")
+
+
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
     return (_TENSOR_MODEL_PARALLEL_GROUP is not None
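A sketch (not part of the commit) of why this helper is needed: the multi-step tests above construct two workers in the same process, so the model-parallel groups already exist when the second worker initializes. The names and the claim about double initialization below are based on the context shown in this hunk; the snippet assumes torch.distributed has already been initialized (e.g. by Worker.init_model()) on a single GPU.

# Sketch: ensure_model_parallel_initialized() tolerates already-initialized groups,
# whereas calling initialize_model_parallel() a second time would trip its
# already-initialized assertion.
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized, model_parallel_is_initialized)

ensure_model_parallel_initialized(tensor_model_parallel_size=1,
                                  pipeline_model_parallel_size=1)
assert model_parallel_is_initialized()

# Safe to call again with the same sizes; asserts only if the sizes differ.
ensure_model_parallel_initialized(1, 1)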
vllm/utils.py

@@ -65,6 +65,10 @@ def get_ip() -> str:
     return s.getsockname()[0]
 
 
+def get_distributed_init_method(ip: str, port: int) -> str:
+    return f"tcp://{ip}:{port}"
+
+
 def get_open_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
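The new helper simply centralizes the tcp:// address formatting that was previously inlined in llm_engine.py. For illustration (the IP and port values are arbitrary):

from vllm.utils import get_distributed_init_method

# Equivalent to the f-string it replaces: f"tcp://{ip}:{port}"
assert get_distributed_init_method("10.0.0.1", 29500) == "tcp://10.0.0.1:29500"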
vllm/worker/model_runner.py

@@ -277,8 +277,8 @@ class ModelRunner:
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
-            max_block_table_len = (max_context_len + self.block_size -
-                                   1) // self.block_size
+            max_block_table_len = max(
+                len(block_table) for block_table in block_tables)
             block_tables = _make_tensor_with_pad(
                 block_tables,
                 max_len=max_block_table_len,
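The change above derives the padding width from the block tables actually passed in rather than from max_context_len. A toy illustration (plain Python; the block ids are made up):

# Toy illustration of the new padding-width computation.
block_tables = [[0, 1], [2, 3, 4], [5]]                 # physical block ids per sequence
max_block_table_len = max(len(bt) for bt in block_tables)
assert max_block_table_len == 3                          # every row is padded to length 3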
vllm/worker/spec_decode/multi_step_worker.py (new file, mode 100644)
from typing import List, Dict
import copy

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    @torch.inference_mode()
    def execute_model_multi_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times. Returns the list of
        sampler output, one per model forward pass.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)

        # Run model num_steps times.
        model_outputs = []
        for _ in range(num_steps):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs

    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: SequenceGroupMetadata) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        The multi-step worker must be able to append tokens to sequences after
        a forward pass. This necessitates modification of the data structures
        used by the worker. Since these data structures are shared with other
        parts of vLLM, like the scheduler, we must take care not to introduce
        unexpected side-effects.

        When Ray is used to orchestrate worker processes (such as when the
        tensor-parallel degree is >1), this is not a problem because the input
        datastructures will be serialized and created anew in the worker
        process.

        However, when Ray is not used to orchestrate the worker processes (such
        as when the tensor-parallel degree is 1), this is a problem. We avoid
        the problem by shallow-copying the input datastructures (specifically,
        the parts that will change in multiple steps).
        """

        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")
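_shallow_copy_inputs above copies only the containers the worker mutates (the metadata object, its seq_data dict values, and the output-token list), so appending draft tokens never leaks back to the caller. A standalone sketch of the same copying strategy, using toy classes rather than vLLM's types:

# Standalone sketch (toy classes, not vLLM types) of the shallow-copy strategy.
import copy

class ToySeqData:
    def __init__(self, output_token_ids):
        self.output_token_ids = output_token_ids

class ToyMetadata:
    def __init__(self, seq_data):
        self.seq_data = seq_data
        self.is_prompt = True

original = ToyMetadata({0: ToySeqData([7, 8])})

copied = copy.copy(original)                       # new metadata object, shared fields
copied.seq_data = {
    seq_id: copy.copy(data) for seq_id, data in original.seq_data.items()
}
for data in copied.seq_data.values():
    data.output_token_ids = data.output_token_ids[:]    # private token list

copied.is_prompt = False
copied.seq_data[0].output_token_ids.append(9)            # simulate an appended draft token

assert original.is_prompt is True
assert original.seq_data[0].output_token_ids == [7, 8]   # caller's data untouched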
vllm/worker/worker.py

@@ -11,7 +11,7 @@ from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel)
+    ensure_model_parallel_initialized)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner

@@ -227,8 +227,8 @@ def _init_distributed_environment(
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
-    initialize_model_parallel(parallel_config.tensor_parallel_size,
-                              parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
 
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):