Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4f921b7
Unverified
Commit
f4f921b7
authored
Apr 29, 2024
by
youkaichao
Committed by
GitHub
Apr 29, 2024
Browse files
[Core][Distributed] use cpu group to broadcast metadata in cpu (#4444)
parent
ac5ccf01
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
35 deletions
+63
-35
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
+3
-3
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+12
-11
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+48
-21
No files found.
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
View file @
f4f921b7
...
@@ -6,14 +6,14 @@ import uuid
...
@@ -6,14 +6,14 @@ import uuid
from
functools
import
partial
from
functools
import
partial
from
typing
import
Type
from
typing
import
Type
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
TensorDeserializer
,
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
TensorDeserializer
,
TensorSerializer
,
stream_io
)
TensorSerializer
,
stream_io
)
from
tensorizer.utils
import
convert_bytes
,
get_mem_usage
,
no_init_or_tensor
from
tensorizer.utils
import
convert_bytes
,
get_mem_usage
,
no_init_or_tensor
from
transformers
import
AutoConfig
,
PretrainedConfig
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.distributed
import
initialize_model_parallel
from
vllm.distributed
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
...
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
...
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
os
.
environ
[
"MASTER_ADDR"
]
=
"127.0.0.1"
os
.
environ
[
"MASTER_ADDR"
]
=
"127.0.0.1"
os
.
environ
[
"MASTER_PORT"
]
=
"8080"
os
.
environ
[
"MASTER_PORT"
]
=
"8080"
torch
.
distributed
.
init_process_group
(
world_size
=
1
,
rank
=
0
)
init_
distributed
_environment
(
world_size
=
1
,
rank
=
0
,
local_
rank
=
0
)
initialize_model_parallel
()
initialize_model_parallel
()
keyfile
=
args
.
keyfile
if
args
.
keyfile
else
None
keyfile
=
args
.
keyfile
if
args
.
keyfile
else
None
...
...
tests/worker/test_model_runner.py
View file @
f4f921b7
...
@@ -2,8 +2,10 @@ import pytest
...
@@ -2,8 +2,10 @@ import pytest
import
torch
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -249,19 +251,18 @@ def test_empty_seq_group():
...
@@ -249,19 +251,18 @@ def test_empty_seq_group():
assert
len
(
return_prompt_lens
)
==
0
assert
len
(
return_prompt_lens
)
==
0
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
@
pytest
.
fixture
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
distributed_init
():
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
monkeypatch
):
init_distributed_environment
(
world_size
=
1
,
def
get_world_size
(
group
=
None
):
rank
=
0
,
return
1
distributed_init_method
=
f
"tcp://127.0.0.1:
{
get_open_port
()
}
"
,
local_rank
=
0
)
def
mock_get_process_group_ranks
(
group
=
None
):
return
[
0
]
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_world_size"
,
get_world_size
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
))
)
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_process_group_ranks"
,
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
mock_get_process_group_ranks
)
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
distributed_init
):
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
...
...
vllm/distributed/communication_op.py
View file @
f4f921b7
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
.parallel_state
import
(
get_tensor_model_parallel_group
,
from
.parallel_state
import
(
get_cpu_world_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
is_pynccl_enabled_for_all_reduce
)
is_pynccl_enabled_for_all_reduce
)
...
@@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
...
@@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"dtype"
,
"size"
])
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"dtype"
,
"size"
])
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list
=
[]
tensor_list
=
[]
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Note(youkaichao): currently this only supports broadcasting
# tensors on cuda. In the future, we can add device as a field in
# TensorMetadata to support broadcasting tensors on different
# devices.
assert
value
.
is_cuda
,
(
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
f
"support broadcasting tensors on cuda."
)
metadata_list
.
append
((
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
else
:
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
def
broadcast_tensor_dict
(
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
src
:
int
=
0
,
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
,
group
:
Optional
[
ProcessGroup
]
=
None
,
metadata_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Broadcast the input tensor dictionary."""
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
group
=
group
or
torch
.
distributed
.
group
.
WORLD
group
=
group
or
torch
.
distributed
.
group
.
WORLD
metadata_group
=
metadata_group
or
get_cpu_world_group
()
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
...
@@ -161,22 +195,15 @@ def broadcast_tensor_dict(
...
@@ -161,22 +195,15 @@ def broadcast_tensor_dict(
assert
isinstance
(
assert
isinstance
(
tensor_dict
,
tensor_dict
,
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
for
key
,
value
in
tensor_dict
.
items
():
metadata_list
,
tensor_list
=
_split_tensor_dict
(
tensor_dict
)
if
isinstance
(
value
,
torch
.
Tensor
):
# `metadata_list` lives in CPU memory.
assert
value
.
is_cuda
,
(
# `broadcast_object_list` involves serialization and deserialization,
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
# all happening on CPU. Therefore, we can use the CPU group.
f
"support broadcasting tensors on cuda."
)
metadata_list
.
append
(
(
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
else
:
metadata_list
.
append
((
key
,
value
))
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
src
=
src
,
src
=
src
,
group
=
group
)
group
=
metadata_
group
)
async_handles
=
[]
async_handles
=
[]
for
key
,
value
in
metadata_list
:
for
tensor
in
tensor_list
:
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
tensor_dict
[
key
]
async_handles
.
append
(
async_handles
.
append
(
torch
.
distributed
.
broadcast
(
tensor
,
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
src
=
src
,
...
@@ -189,7 +216,7 @@ def broadcast_tensor_dict(
...
@@ -189,7 +216,7 @@ def broadcast_tensor_dict(
recv_metadata_list
=
[
None
]
recv_metadata_list
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
src
=
src
,
src
=
src
,
group
=
group
)
group
=
metadata_
group
)
assert
recv_metadata_list
[
0
]
is
not
None
assert
recv_metadata_list
[
0
]
is
not
None
tensor_dict
=
{}
tensor_dict
=
{}
async_handles
=
[]
async_handles
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment