vllm · Commit ef9b636e (unverified)
Authored Jan 19, 2024 by Zhuohan Li; committed by GitHub on Jan 19, 2024
Parent: 2709c000

Simplify broadcast logic for control messages (#2501)
Showing 4 changed files with 146 additions and 129 deletions:

  tests/distributed/test_comm_ops.py                        +33  -2
  vllm/model_executor/parallel_utils/communication_op.py    +68  -5
  vllm/worker/model_runner.py                               +30  -108
  vllm/worker/worker.py                                     +15  -14
tests/distributed/test_comm_ops.py

@@ -11,6 +11,7 @@ from vllm.utils import get_open_port
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_all_gather,
+    broadcast_tensor_dict,
 )
 from vllm.worker.worker import _init_distributed_environment

@@ -64,11 +65,41 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     assert torch.allclose(t, expected)


+@ray.remote(num_gpus=1, max_calls=1)
+def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
+                                      distributed_init_port: str):
+    init_test_distributed_environment(1, tensor_parallel_size, rank,
+                                      distributed_init_port)
+    test_dict = {
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {"a": 1, "b": 2},
+    }
+
+    if rank == 0:
+        broadcast_tensor_dict(test_dict, src=0)
+    else:
+        recv_dict = broadcast_tensor_dict(src=0)
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tensor_parallel_size", [2])
-@pytest.mark.parametrize("test_target",
-                         [all_reduce_test_worker, all_gather_test_worker])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker,
+    broadcast_tensor_dict_test_worker
+])
 def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
     # Using ray helps debugging the error when it failed
     # as compared to multiprocessing.
vllm/model_executor/parallel_utils/communication_op.py

+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Union
+
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -7,7 +10,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 )


-def tensor_model_parallel_all_reduce(input_):
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group.

     NOTE: This operation is applied in-place on the input tensor.
@@ -21,7 +24,8 @@ def tensor_model_parallel_all_reduce(input_):
     return input_


-def tensor_model_parallel_all_gather(input_, dim=-1):
+def tensor_model_parallel_all_gather(input_: torch.Tensor,
+                                     dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
@@ -48,7 +52,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
     return output_tensor


-def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+def tensor_model_parallel_gather(input_: torch.Tensor,
+                                 dst: int = 0,
+                                 dim: int = -1) -> torch.Tensor:
     """Gather the input tensor across model parallel group.

     NOTE: We assume that the input tensor is on the same device across
@@ -80,7 +86,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
     return output_tensor


-def broadcast(input_, src=0):
+def broadcast(input_: torch.Tensor, src: int = 0):
     """Broadcast the input tensor."""
     world_size = torch.distributed.get_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"
@@ -93,7 +99,7 @@ def broadcast(input_, src=0):
     return input_


-def broadcast_object_list(obj_list, src=0):
+def broadcast_object_list(obj_list: List[Any], src: int = 0):
     """Broadcast the input object list."""
     world_size = torch.distributed.get_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"
@@ -104,3 +110,60 @@ def broadcast_object_list(obj_list, src=0):
     # Broadcast.
     torch.distributed.broadcast_object_list(obj_list, src=src)
     return obj_list
+
+
+TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+
+
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
+    """Broadcast the input tensor dictionary."""
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return tensor_dict
+
+    if rank == src:
+        assert isinstance(tensor_dict, dict), (
+            f"Expecting a dictionary, got {type(tensor_dict)}")
+        metadata_list = []
+        for key, value in tensor_dict.items():
+            if isinstance(value, torch.Tensor):
+                assert value.is_cuda, (
+                    f"Tensor {key}: {value} is not on cuda. Currently we only "
+                    f"support broadcasting tensors on cuda.")
+                metadata_list.append(
+                    (key, TensorMetadata(value.dtype, value.size())))
+            else:
+                metadata_list.append((key, value))
+        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = tensor_dict[key]
+                torch.distributed.broadcast(tensor, src=src)
+    else:
+        recv_metadata_list = [None]
+        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        metadata_list = recv_metadata_list[0]
+        tensor_dict = {}
+        async_handles = []
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device="cuda")
+                async_handle = torch.distributed.broadcast(tensor,
+                                                           src=src,
+                                                           async_op=True)
+                async_handles.append(async_handle)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        for async_handle in async_handles:
+            async_handle.wait()
+    return tensor_dict
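As a usage note (a minimal sketch, not part of this commit: it assumes a torch.distributed process group has already been initialized with at least two ranks and that each rank has selected its own CUDA device; the helper name sync_control_message is made up for illustration), the calling convention is: the source rank passes the dictionary, every other rank calls with no arguments and receives the same dictionary back.

import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def sync_control_message(step: int) -> dict:
    # Hypothetical helper: rank 0 packs CUDA tensors and plain Python values
    # into one dict; other ranks call with no arguments and get the same
    # dict back (tensor shapes/dtypes travel via the pickled metadata list,
    # tensor payloads via NCCL broadcasts).
    if dist.get_rank() == 0:
        msg = {
            "tokens": torch.arange(8, dtype=torch.long, device="cuda"),
            "step": step,
            "is_prompt": True,
        }
        return broadcast_tensor_dict(msg, src=0)
    return broadcast_tensor_dict(src=0)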
vllm/worker/model_runner.py

@@ -9,7 +9,7 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast, broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl
@@ -393,121 +393,43 @@ class ModelRunner:
                 prompt_lens,
                 subquery_lens)

-            def get_size_or_none(x: Optional[torch.Tensor]):
-                return x.size() if x is not None else None
-
-            # Broadcast the input data. For input tensors, we first broadcast
-            # its shape and then broadcast the tensor to avoid high
-            # serialization cost.
-            py_data = {
-                "input_tokens_size": input_tokens.size(),
-                "input_positions_size": input_positions.size(),
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping),
-                "prompt_lens_size": get_size_or_none(input_metadata.prompt_lens),
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc_size": get_size_or_none(input_metadata.start_loc),
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens_size": get_size_or_none(input_metadata.context_lens),
-                "block_tables_size": get_size_or_none(input_metadata.block_tables),
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "selected_token_indices_size":
-                sampling_metadata.selected_token_indices.size(),
-            }
-            broadcast_object_list([py_data], src=0)
-            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
-            broadcast(input_tokens, src=0)
-            broadcast(input_positions, src=0)
-            if input_metadata.slot_mapping is not None:
-                broadcast(input_metadata.slot_mapping, src=0)
-            if input_metadata.prompt_lens is not None:
-                broadcast(input_metadata.prompt_lens, src=0)
-            if input_metadata.start_loc is not None:
-                broadcast(input_metadata.start_loc, src=0)
-            if input_metadata.context_lens is not None:
-                broadcast(input_metadata.context_lens, src=0)
-            if input_metadata.block_tables is not None:
-                broadcast(input_metadata.block_tables, src=0)
-            broadcast(sampling_metadata.selected_token_indices, src=0)
+            # Broadcast the metadata.
+            metadata_dict = {
+                "input_tokens": input_tokens,
+                "input_positions": input_positions,
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping": input_metadata.slot_mapping,
+                "prompt_lens": input_metadata.prompt_lens,
+                "max_seq_len": input_metadata.max_seq_len,
+                "start_loc": input_metadata.start_loc,
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens": input_metadata.context_lens,
+                "block_tables": input_metadata.block_tables,
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices":
+                sampling_metadata.selected_token_indices,
+            }
+            broadcast_tensor_dict(metadata_dict, src=0)
         else:
-            receving_list = [None]
-            broadcast_object_list(receving_list, src=0)
-            py_data = receving_list[0]
-            input_tokens = torch.empty(*py_data["input_tokens_size"],
-                                       dtype=torch.long,
-                                       device="cuda")
-            broadcast(input_tokens, src=0)
-            input_positions = torch.empty(*py_data["input_positions_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-            broadcast(input_positions, src=0)
-            if py_data["slot_mapping_size"] is not None:
-                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
-                                           dtype=torch.long,
-                                           device="cuda")
-                broadcast(slot_mapping, src=0)
-            else:
-                slot_mapping = None
-            if py_data["prompt_lens_size"] is not None:
-                prompt_lens = torch.empty(*py_data["prompt_lens_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-                broadcast(prompt_lens, src=0)
-            else:
-                prompt_lens = None
-            if py_data["start_loc_size"] is not None:
-                start_loc = torch.empty(*py_data["start_loc_size"],
-                                        dtype=torch.long,
-                                        device="cuda")
-                broadcast(start_loc, src=0)
-            else:
-                start_loc = None
-            if py_data["context_lens_size"] is not None:
-                context_lens = torch.empty(*py_data["context_lens_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(context_lens, src=0)
-            else:
-                context_lens = None
-            if py_data["block_tables_size"] is not None:
-                block_tables = torch.empty(*py_data["block_tables_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(block_tables, src=0)
-            else:
-                block_tables = None
-            selected_token_indices = torch.empty(
-                *py_data["selected_token_indices_size"],
-                dtype=torch.long,
-                device="cuda")
-            broadcast(selected_token_indices, src=0)
+            metadata_dict = broadcast_tensor_dict(src=0)
+            input_tokens = metadata_dict["input_tokens"]
+            input_positions = metadata_dict["input_positions"]
             input_metadata = InputMetadata(
-                is_prompt=py_data["is_prompt"],
-                slot_mapping=slot_mapping,
-                prompt_lens=prompt_lens,
-                max_seq_len=py_data["max_seq_len"],
-                start_loc=start_loc,
-                max_context_len=py_data["max_context_len"],
-                context_lens=context_lens,
-                block_tables=block_tables,
-                use_cuda_graph=py_data["use_cuda_graph"],
+                is_prompt=metadata_dict["is_prompt"],
+                slot_mapping=metadata_dict["slot_mapping"],
+                prompt_lens=metadata_dict["prompt_lens"],
+                max_seq_len=metadata_dict["max_seq_len"],
+                start_loc=metadata_dict["start_loc"],
+                max_context_len=metadata_dict["max_context_len"],
+                context_lens=metadata_dict["context_lens"],
+                block_tables=metadata_dict["block_tables"],
+                use_cuda_graph=metadata_dict["use_cuda_graph"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=selected_token_indices,
+                selected_token_indices=metadata_dict["selected_token_indices"],
                 categorized_sample_indices=None,
                 perform_sampling=False,
             )
vllm/worker/worker.py

@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -175,20 +175,21 @@ class Worker:
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
-            block_swapping_info = [
-                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
-            ]
-            broadcast_object_list([num_seq_groups] + block_swapping_info,
-                                  src=0)
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
         else:
-            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
-            # blocks_to_copy (4 elements)
-            recv_data = [None] * 4
-            broadcast_object_list(recv_data, src=0)
-            num_seq_groups = recv_data[0]
-            block_swapping_info = recv_data[1:]
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]

-        self.cache_swap(*block_swapping_info)
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

         # If there is no input, we don't need to execute the model.
         if num_seq_groups == 0:
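For completeness, here is a rough standalone harness (not part of the commit; it assumes two visible GPUs and uses plain torch.multiprocessing instead of the Ray-based test above, with the names _demo_worker and the port 29500 chosen arbitrarily) that drives the same driver/worker control-message round trip that worker.py and model_runner.py now rely on.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def _demo_worker(rank: int, world_size: int, port: int):
    # One process per GPU; NCCL needs a distinct device per rank.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl",
                            init_method=f"tcp://127.0.0.1:{port}",
                            world_size=world_size,
                            rank=rank)
    if rank == 0:
        # The driver packs CUDA tensors and plain Python control values
        # into a single dict and broadcasts it in one call.
        broadcast_tensor_dict(
            {
                "blocks_to_copy": {0: [1]},
                "tokens": torch.ones(4, device="cuda"),
            },
            src=0)
    else:
        # Non-driver ranks receive the whole dict back in one call.
        received = broadcast_tensor_dict(src=0)
        assert received["blocks_to_copy"] == {0: [1]}
        assert torch.allclose(received["tokens"],
                              torch.ones(4, device="cuda"))
    dist.destroy_process_group()


if __name__ == "__main__":
    # Launch a 2-process round trip on localhost (requires 2 GPUs).
    mp.spawn(_demo_worker, args=(2, 29500), nprocs=2)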