norm / vllm / Commits / ef9b636e

Unverified commit ef9b636e, authored Jan 19, 2024 by Zhuohan Li, committed by GitHub on Jan 19, 2024

Simplify broadcast logic for control messages (#2501)

parent 2709c000
Showing 4 changed files with 146 additions and 129 deletions (+146 / -129)
tests/distributed/test_comm_ops.py                        +33 / -2
vllm/model_executor/parallel_utils/communication_op.py    +68 / -5
vllm/worker/model_runner.py                               +30 / -108
vllm/worker/worker.py                                     +15 / -14
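Taken together, the change replaces the old driver/worker protocol (a broadcast_object_list of Python metadata followed by one broadcast call per tensor) with a single broadcast_tensor_dict call on both sides. A minimal sketch of the resulting calling pattern, assuming an already-initialized torch.distributed process group; the wrapper function and its arguments below are illustrative only and not part of the commit:

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def exchange_control_data(is_driver_worker: bool, control_dict=None):
    if is_driver_worker:
        # Driver rank: pass the dictionary. CUDA tensors are sent as
        # dtype/size metadata plus a broadcast per tensor; everything else
        # travels through a single broadcast_object_list.
        return broadcast_tensor_dict(control_dict, src=0)
    # All other ranks: call with no dictionary and get the rebuilt one back.
    return broadcast_tensor_dict(src=0)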
tests/distributed/test_comm_ops.py
@@ -11,6 +11,7 @@ from vllm.utils import get_open_port
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_all_gather,
+    broadcast_tensor_dict,
 )
 from vllm.worker.worker import _init_distributed_environment
@@ -64,11 +65,41 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     assert torch.allclose(t, expected)


+@ray.remote(num_gpus=1, max_calls=1)
+def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
+                                      distributed_init_port: str):
+    init_test_distributed_environment(1, tensor_parallel_size, rank,
+                                      distributed_init_port)
+    test_dict = {
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+    }
+
+    if rank == 0:
+        broadcast_tensor_dict(test_dict, src=0)
+    else:
+        recv_dict = broadcast_tensor_dict(src=0)
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tensor_parallel_size", [2])
-@pytest.mark.parametrize("test_target",
-                         [all_reduce_test_worker, all_gather_test_worker])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker,
+    broadcast_tensor_dict_test_worker
+])
 def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
     # Using ray helps debugging the error when it failed
     # as compared to multiprocessing.
vllm/model_executor/parallel_utils/communication_op.py
+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Union
+
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (

@@ -7,7 +10,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 )


-def tensor_model_parallel_all_reduce(input_):
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group.

     NOTE: This operation is applied in-place on the input tensor.

@@ -21,7 +24,8 @@ def tensor_model_parallel_all_reduce(input_):
     return input_
-def tensor_model_parallel_all_gather(input_, dim=-1):
+def tensor_model_parallel_all_gather(input_: torch.Tensor,
+                                     dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.

@@ -48,7 +52,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
     return output_tensor


-def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+def tensor_model_parallel_gather(input_: torch.Tensor,
+                                 dst: int = 0,
+                                 dim: int = -1) -> torch.Tensor:
     """Gather the input tensor across model parallel group.

     NOTE: We assume that the input tensor is on the same device across

@@ -80,7 +86,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
     return output_tensor
-def broadcast(input_, src=0):
+def broadcast(input_: torch.Tensor, src: int = 0):
     """Broadcast the input tensor."""
     world_size = torch.distributed.get_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"

@@ -93,7 +99,7 @@ def broadcast(input_, src=0):
     return input_


-def broadcast_object_list(obj_list, src=0):
+def broadcast_object_list(obj_list: List[Any], src: int = 0):
     """Broadcast the input object list."""
     world_size = torch.distributed.get_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"

@@ -104,3 +110,60 @@ def broadcast_object_list(obj_list, src=0):
     # Broadcast.
     torch.distributed.broadcast_object_list(obj_list, src=src)
     return obj_list
+
+
+TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+
+
+def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+                                                                 Any]]] = None,
+                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+    """Broadcast the input tensor dictionary."""
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return tensor_dict
+
+    if rank == src:
+        assert isinstance(tensor_dict, dict), (
+            f"Expecting a dictionary, got {type(tensor_dict)}")
+        metadata_list = []
+        for key, value in tensor_dict.items():
+            if isinstance(value, torch.Tensor):
+                assert value.is_cuda, (
+                    f"Tensor {key}: {value} is not on cuda. Currently we only "
+                    f"support broadcasting tensors on cuda.")
+                metadata_list.append(
+                    (key, TensorMetadata(value.dtype, value.size())))
+            else:
+                metadata_list.append((key, value))
+        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = tensor_dict[key]
+                torch.distributed.broadcast(tensor, src=src)
+    else:
+        recv_metadata_list = [None]
+        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        metadata_list = recv_metadata_list[0]
+        tensor_dict = {}
+        async_handles = []
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device="cuda")
+                async_handle = torch.distributed.broadcast(tensor,
+                                                           src=src,
+                                                           async_op=True)
+                async_handles.append(async_handle)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        for async_handle in async_handles:
+            async_handle.wait()
+    return tensor_dict
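For reference, a self-contained way to exercise the new helper outside the ray-based test above could look like the following. This is a sketch under assumptions, not part of the commit: it presumes two CUDA devices, NCCL, and a free local port, and the worker function plus its dictionary contents are made up for illustration.

import os

import torch
import torch.multiprocessing as mp

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def run_worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group("nccl",
                                         rank=rank,
                                         world_size=world_size)
    if rank == 0:
        # Driver rank: tensors must live on CUDA; other values can be
        # arbitrary picklable Python objects.
        broadcast_tensor_dict(
            {
                "tokens": torch.arange(8, device="cuda"),
                "step": 3,
            }, src=0)
    else:
        # Non-driver ranks receive the reconstructed dictionary.
        received = broadcast_tensor_dict(src=0)
        assert received["step"] == 3
        assert received["tokens"].shape == (8, )
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2, ), nprocs=2)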
vllm/worker/model_runner.py
@@ -9,7 +9,7 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast, broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl
@@ -393,121 +393,43 @@ class ModelRunner:
                 prompt_lens, subquery_lens)

-            def get_size_or_none(x: Optional[torch.Tensor]):
-                return x.size() if x is not None else None
-
-            # Broadcast the input data. For input tensors, we first broadcast
-            # its shape and then broadcast the tensor to avoid high
-            # serialization cost.
-            py_data = {
-                "input_tokens_size": input_tokens.size(),
-                "input_positions_size": input_positions.size(),
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping),
-                "prompt_lens_size": get_size_or_none(input_metadata.prompt_lens),
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc_size": get_size_or_none(input_metadata.start_loc),
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens_size": get_size_or_none(input_metadata.context_lens),
-                "block_tables_size": get_size_or_none(input_metadata.block_tables),
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "selected_token_indices_size":
-                sampling_metadata.selected_token_indices.size(),
-            }
-            broadcast_object_list([py_data], src=0)
-            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
-            broadcast(input_tokens, src=0)
-            broadcast(input_positions, src=0)
-            if input_metadata.slot_mapping is not None:
-                broadcast(input_metadata.slot_mapping, src=0)
-            if input_metadata.prompt_lens is not None:
-                broadcast(input_metadata.prompt_lens, src=0)
-            if input_metadata.start_loc is not None:
-                broadcast(input_metadata.start_loc, src=0)
-            if input_metadata.context_lens is not None:
-                broadcast(input_metadata.context_lens, src=0)
-            if input_metadata.block_tables is not None:
-                broadcast(input_metadata.block_tables, src=0)
-            broadcast(sampling_metadata.selected_token_indices, src=0)
-        else:
-            receving_list = [None]
-            broadcast_object_list(receving_list, src=0)
-            py_data = receving_list[0]
-            input_tokens = torch.empty(*py_data["input_tokens_size"],
-                                       dtype=torch.long,
-                                       device="cuda")
-            broadcast(input_tokens, src=0)
-            input_positions = torch.empty(*py_data["input_positions_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-            broadcast(input_positions, src=0)
-            if py_data["slot_mapping_size"] is not None:
-                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
-                                           dtype=torch.long,
-                                           device="cuda")
-                broadcast(slot_mapping, src=0)
-            else:
-                slot_mapping = None
-            if py_data["prompt_lens_size"] is not None:
-                prompt_lens = torch.empty(*py_data["prompt_lens_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-                broadcast(prompt_lens, src=0)
-            else:
-                prompt_lens = None
-            if py_data["start_loc_size"] is not None:
-                start_loc = torch.empty(*py_data["start_loc_size"],
-                                        dtype=torch.long,
-                                        device="cuda")
-                broadcast(start_loc, src=0)
-            else:
-                start_loc = None
-            if py_data["context_lens_size"] is not None:
-                context_lens = torch.empty(*py_data["context_lens_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(context_lens, src=0)
-            else:
-                context_lens = None
-            if py_data["block_tables_size"] is not None:
-                block_tables = torch.empty(*py_data["block_tables_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(block_tables, src=0)
-            else:
-                block_tables = None
-            selected_token_indices = torch.empty(
-                *py_data["selected_token_indices_size"],
-                dtype=torch.long,
-                device="cuda")
-            broadcast(selected_token_indices, src=0)
-            input_metadata = InputMetadata(
-                is_prompt=py_data["is_prompt"],
-                slot_mapping=slot_mapping,
-                prompt_lens=prompt_lens,
-                max_seq_len=py_data["max_seq_len"],
-                start_loc=start_loc,
-                max_context_len=py_data["max_context_len"],
-                context_lens=context_lens,
-                block_tables=block_tables,
-                use_cuda_graph=py_data["use_cuda_graph"],
-            )
-            sampling_metadata = SamplingMetadata(
-                seq_groups=None,
-                seq_data=None,
-                prompt_lens=None,
-                selected_token_indices=selected_token_indices,
-                categorized_sample_indices=None,
-                perform_sampling=False,
-            )
+            # Broadcast the metadata.
+            metadata_dict = {
+                "input_tokens": input_tokens,
+                "input_positions": input_positions,
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping": input_metadata.slot_mapping,
+                "prompt_lens": input_metadata.prompt_lens,
+                "max_seq_len": input_metadata.max_seq_len,
+                "start_loc": input_metadata.start_loc,
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens": input_metadata.context_lens,
+                "block_tables": input_metadata.block_tables,
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices":
+                sampling_metadata.selected_token_indices,
+            }
+            broadcast_tensor_dict(metadata_dict, src=0)
+        else:
+            metadata_dict = broadcast_tensor_dict(src=0)
+            input_tokens = metadata_dict["input_tokens"]
+            input_positions = metadata_dict["input_positions"]
+            input_metadata = InputMetadata(
+                is_prompt=metadata_dict["is_prompt"],
+                slot_mapping=metadata_dict["slot_mapping"],
+                prompt_lens=metadata_dict["prompt_lens"],
+                max_seq_len=metadata_dict["max_seq_len"],
+                start_loc=metadata_dict["start_loc"],
+                max_context_len=metadata_dict["max_context_len"],
+                context_lens=metadata_dict["context_lens"],
+                block_tables=metadata_dict["block_tables"],
+                use_cuda_graph=metadata_dict["use_cuda_graph"],
+            )
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                seq_data=None,
+                prompt_lens=None,
+                selected_token_indices=metadata_dict["selected_token_indices"],
+                categorized_sample_indices=None,
+                perform_sampling=False,
+            )
vllm/worker/worker.py
@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -175,20 +175,21 @@ class Worker:
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
-            block_swapping_info = [
-                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
-            ]
-            broadcast_object_list([num_seq_groups] + block_swapping_info,
-                                  src=0)
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
         else:
-            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
-            # blocks_to_copy (4 elements)
-            recv_data = [None] * 4
-            broadcast_object_list(recv_data, src=0)
-            num_seq_groups = recv_data[0]
-            block_swapping_info = recv_data[1:]
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]

-        self.cache_swap(*block_swapping_info)
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

         # If there is no input, we don't need to execute the model.
         if num_seq_groups == 0:
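One detail worth noting in this last file: none of the values in the worker's dictionary are tensors, so with the new helper the whole exchange reduces to a single broadcast_object_list of the metadata list and no per-tensor broadcasts are issued. A small illustrative snippet, assuming a process group has already been initialized as in the sketch after the communication_op.py diff; the values are made up:

import torch

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)

if torch.distributed.get_rank() == 0:
    # Driver rank: plain Python control values only, so under the hood this
    # is a single object broadcast with no tensor traffic at all.
    broadcast_tensor_dict(
        {
            "num_seq_groups": 4,
            "blocks_to_swap_in": {},
            "blocks_to_swap_out": {},
            "blocks_to_copy": {3: [7]},
        }, src=0)
else:
    # Every other rank receives the same dictionary back.
    data = broadcast_tensor_dict(src=0)
    assert data["num_seq_groups"] == 4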