norm / vllm · Commits · 5b23c3f2

Unverified commit 5b23c3f2, authored Jan 20, 2024 by Junda Chen, committed by GitHub on Jan 20, 2024

Add `group` as an argument in broadcast ops (#2522)
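Before this change the broadcast helpers always ran on the default (WORLD) process group, so every collective implicitly involved all ranks. The commit threads an optional `group: Optional[ProcessGroup] = None` argument through `broadcast`, `broadcast_object_list`, and `broadcast_tensor_dict`, letting callers restrict a broadcast to a subgroup of ranks. A minimal sketch of the new call pattern follows; the four-rank world and the subgroup of ranks 0 and 1 are illustrative assumptions, not part of the commit:

```python
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import broadcast

# Illustrative setup: dist.init_process_group(...) is assumed to have
# been called already with a world size of 4.
# new_group() must be called by every process, even non-members.
subgroup = dist.new_group(ranks=[0, 1])

if dist.get_rank() in (0, 1):
    tensor = torch.zeros(16, device="cuda")
    # src is a *global* rank and must be a member of `subgroup`.
    broadcast(tensor, src=0, group=subgroup)
```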
parent 00efdc84

Showing 1 changed file with 36 additions and 17 deletions:

vllm/model_executor/parallel_utils/communication_op.py (+36 −17)
```diff
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union

+from torch.distributed import ProcessGroup
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
```
...
```diff
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor


-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_
```
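Note the validation change: `src` is a global rank, so the old check `0 <= src < world_size` would be wrong for a subgroup, whose member ranks need not start at 0 (and `get_world_size(group=group)` now counts only group members). Membership is instead checked against `torch.distributed.get_process_group_ranks`, which requires a reasonably recent PyTorch (it was added around the 2.0 release). A small sketch of why the old check breaks for subgroups; the rank numbers are illustrative:

```python
import torch.distributed as dist

# Illustrative: every process calls new_group(); ranks outside [2, 3]
# get a non-member placeholder back, so only members use the handle.
subgroup = dist.new_group(ranks=[2, 3])

if dist.get_rank() in (2, 3):
    # The group-local world size is 2, but the valid src values are the
    # *global* ranks 2 and 3, so "0 <= src < 2" would reject both.
    assert dist.get_world_size(group=subgroup) == 2
    assert dist.get_process_group_ranks(subgroup) == [2, 3]
```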
```diff
-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list
```
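`broadcast_object_list` gets the same treatment. Its underlying collective, `torch.distributed.broadcast_object_list`, pickles arbitrary Python objects and fills the list in place, so non-source ranks must pass a list of the same length, typically pre-filled with `None` (exactly what `broadcast_tensor_dict` below does with `recv_metadata_list = [None]`). A minimal usage sketch, assuming an initialized default process group; the payload is illustrative:

```python
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_object_list)

if dist.get_rank() == 0:
    objs = [{"step": 1, "stop": False}]  # source provides the objects
else:
    objs = [None]  # receivers pass a same-length placeholder list
broadcast_object_list(objs, src=0)  # filled in place on every rank
print(objs[0])
```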
```diff
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(tensor_dict,
```
...
```diff
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                     (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
```
...
```diff
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                  device="cuda")
             async_handle = torch.distributed.broadcast(tensor,
                                                        src=src,
-                                                       async_op=True)
+                                                       async_op=True,
+                                                       group=group)
             async_handles.append(async_handle)
             tensor_dict[key] = tensor
         else:
```
...
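Taken together, `broadcast_tensor_dict` is a two-phase protocol: the source first broadcasts a pickled metadata list (dtype and size per key) via `broadcast_object_list`, then broadcasts each tensor's raw data; receivers allocate empty CUDA tensors from the metadata, issue the receives with `async_op=True`, and wait on the handles afterwards. A minimal sketch of the calling convention under the new signature; the keys and shapes are illustrative assumptions:

```python
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)

# Assumes an initialized process group and CUDA devices on every rank.
if dist.get_rank() == 0:
    data = {
        "input_ids": torch.randint(0, 1000, (8, 128), device="cuda"),
        "num_seqs": 8,  # non-tensor values travel in the metadata list
    }
    broadcast_tensor_dict(data, src=0)
else:
    # Non-source ranks pass no dict and get the reconstructed one back.
    data = broadcast_tensor_dict(src=0)
```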