Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4f921b7
Unverified
Commit
f4f921b7
authored
Apr 29, 2024
by
youkaichao
Committed by
GitHub
Apr 29, 2024
Browse files
[Core][Distributed] use cpu group to broadcast metadata in cpu (#4444)
parent
ac5ccf01
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
35 deletions
+63
-35
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
+3
-3
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+12
-11
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+48
-21
No files found.
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
View file @
f4f921b7
...
...
@@ -6,14 +6,14 @@ import uuid
from
functools
import
partial
from
typing
import
Type
import
torch
import
torch.nn
as
nn
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
TensorDeserializer
,
TensorSerializer
,
stream_io
)
from
tensorizer.utils
import
convert_bytes
,
get_mem_usage
,
no_init_or_tensor
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.distributed
import
initialize_model_parallel
from
vllm.distributed
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
...
...
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
os
.
environ
[
"MASTER_ADDR"
]
=
"127.0.0.1"
os
.
environ
[
"MASTER_PORT"
]
=
"8080"
torch
.
distributed
.
init_process_group
(
world_size
=
1
,
rank
=
0
)
init_
distributed
_environment
(
world_size
=
1
,
rank
=
0
,
local_
rank
=
0
)
initialize_model_parallel
()
keyfile
=
args
.
keyfile
if
args
.
keyfile
else
None
...
...
tests/worker/test_model_runner.py
View file @
f4f921b7
...
...
@@ -2,8 +2,10 @@ import pytest
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
...
@@ -249,19 +251,18 @@ def test_empty_seq_group():
assert
len
(
return_prompt_lens
)
==
0
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
monkeypatch
):
def
get_world_size
(
group
=
None
):
return
1
@
pytest
.
fixture
def
distributed_init
():
init_distributed_environment
(
world_size
=
1
,
rank
=
0
,
distributed_init_method
=
f
"tcp://127.0.0.1:
{
get_open_port
()
}
"
,
local_rank
=
0
)
def
mock_get_process_group_ranks
(
group
=
None
):
return
[
0
]
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_world_size"
,
get_world_size
)
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_process_group_ranks"
,
mock_get_process_group_ranks
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
))
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
distributed_init
):
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
...
...
vllm/distributed/communication_op.py
View file @
f4f921b7
...
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import
torch
from
torch.distributed
import
ProcessGroup
from
.parallel_state
import
(
get_tensor_model_parallel_group
,
from
.parallel_state
import
(
get_cpu_world_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
is_pynccl_enabled_for_all_reduce
)
...
...
@@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"dtype"
,
"size"
])
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list
=
[]
tensor_list
=
[]
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Note(youkaichao): currently this only supports broadcasting
# tensors on cuda. In the future, we can add device as a field in
# TensorMetadata to support broadcasting tensors on different
# devices.
assert
value
.
is_cuda
,
(
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
f
"support broadcasting tensors on cuda."
)
metadata_list
.
append
((
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
else
:
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
,
metadata_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Broadcast the input tensor dictionary."""
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
group
=
group
or
torch
.
distributed
.
group
.
WORLD
metadata_group
=
metadata_group
or
get_cpu_world_group
()
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
...
...
@@ -161,22 +195,15 @@ def broadcast_tensor_dict(
assert
isinstance
(
tensor_dict
,
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
assert
value
.
is_cuda
,
(
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
f
"support broadcasting tensors on cuda."
)
metadata_list
.
append
(
(
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
else
:
metadata_list
.
append
((
key
,
value
))
metadata_list
,
tensor_list
=
_split_tensor_dict
(
tensor_dict
)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
src
=
src
,
group
=
group
)
group
=
metadata_
group
)
async_handles
=
[]
for
key
,
value
in
metadata_list
:
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
tensor_dict
[
key
]
for
tensor
in
tensor_list
:
async_handles
.
append
(
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
...
...
@@ -189,7 +216,7 @@ def broadcast_tensor_dict(
recv_metadata_list
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
src
=
src
,
group
=
group
)
group
=
metadata_
group
)
assert
recv_metadata_list
[
0
]
is
not
None
tensor_dict
=
{}
async_handles
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment