Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99aa4edd
Unverified
Commit
99aa4edd
authored
Sep 16, 2024
by
youkaichao
Committed by
GitHub
Sep 16, 2024
Browse files
[torch.compile] register allreduce operations as custom ops (#8526)
parent
ee2bceaa
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
137 additions
and
50 deletions
+137
-50
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-7
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+0
-12
csrc/ops.h
csrc/ops.h
+0
-2
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-5
tests/compile/__init__.py
tests/compile/__init__.py
+0
-0
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+13
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-6
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+19
-2
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+102
-14
No files found.
.buildkite/test-pipeline.yaml
View file @
99aa4edd
...
@@ -163,13 +163,6 @@ steps:
...
@@ -163,13 +163,6 @@ steps:
-
python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
-
python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
-
python3 offline_inference_encoder_decoder.py
-
python3 offline_inference_encoder_decoder.py
-
label
:
torch compile integration test
source_file_dependencies
:
-
vllm/
commands
:
-
pytest -v -s ./compile/test_full_graph.py
-
pytest -v -s ./compile/test_wrapper.py
-
label
:
Prefix Caching Test
# 7min
-
label
:
Prefix Caching Test
# 7min
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies
:
source_file_dependencies
:
...
@@ -348,7 +341,10 @@ steps:
...
@@ -348,7 +341,10 @@ steps:
-
vllm/executor/
-
vllm/executor/
-
vllm/model_executor/models/
-
vllm/model_executor/models/
-
tests/distributed/
-
tests/distributed/
-
vllm/compilation
commands
:
commands
:
-
pytest -v -s ./compile/test_full_graph.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
# Avoid importing model tests that cause CUDA reinitialization error
...
...
csrc/custom_all_reduce.cu
View file @
99aa4edd
...
@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
...
@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t
.
numel
()
*
t
.
element_size
());
t
.
numel
()
*
t
.
element_size
());
}
}
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int64_t
max_size
,
int64_t
world_size
,
bool
full_nvlink
)
{
auto
inp_size
=
inp
.
numel
()
*
inp
.
element_size
();
// custom allreduce requires input byte size to be multiples of 16
if
(
inp_size
%
16
!=
0
)
return
false
;
if
(
!
_is_weak_contiguous
(
inp
))
return
false
;
if
(
world_size
==
2
||
full_nvlink
)
return
inp_size
<=
max_size
;
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return
false
;
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
...
...
csrc/ops.h
View file @
99aa4edd
...
@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
...
@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
bool
full_nvlink
);
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int64_t
max_size
,
int64_t
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
torch
::
Tensor
&
out
);
...
...
csrc/torch_bindings.cpp
View file @
99aa4edd
...
@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
...
@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"bool full_nvlink) -> int"
);
"bool full_nvlink) -> int"
);
custom_ar
.
impl
(
"init_custom_ar"
,
torch
::
kCUDA
,
&
init_custom_ar
);
custom_ar
.
impl
(
"init_custom_ar"
,
torch
::
kCUDA
,
&
init_custom_ar
);
custom_ar
.
def
(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool"
);
custom_ar
.
impl
(
"should_custom_ar"
,
torch
::
kCUDA
,
&
should_custom_ar
);
custom_ar
.
def
(
"all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"
);
custom_ar
.
def
(
"all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"
);
custom_ar
.
impl
(
"all_reduce_reg"
,
torch
::
kCUDA
,
&
all_reduce_reg
);
custom_ar
.
impl
(
"all_reduce_reg"
,
torch
::
kCUDA
,
&
all_reduce_reg
);
...
...
tests/compile/__init__.py
0 → 100644
View file @
99aa4edd
tests/compile/test_full_graph.py
View file @
99aa4edd
...
@@ -2,9 +2,20 @@ import os
...
@@ -2,9 +2,20 @@ import os
import
pytest
import
pytest
from
vllm.utils
import
cuda_device_count_stateless
from
..utils
import
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
def
test_full_graph
(
model
):
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
fork_new_process_for_each_test
def
test_full_graph
(
model
,
tp_size
):
# Skip the test if there are not enough CUDA devices.
if
cuda_device_count_stateless
()
<
tp_size
:
pytest
.
skip
(
"Not enough CUDA devices for the test."
)
# make sure these models can be captured in full graph mode
# make sure these models can be captured in full graph mode
if
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
not
in
os
.
environ
:
if
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
not
in
os
.
environ
:
os
.
environ
[
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
]
=
"1"
os
.
environ
[
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
]
=
"1"
...
@@ -17,7 +28,7 @@ def test_full_graph(model):
...
@@ -17,7 +28,7 @@ def test_full_graph(model):
"The future of AI is"
,
"The future of AI is"
,
]
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
,
tensor_parallel_size
=
tp_size
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
vllm/_custom_ops.py
View file @
99aa4edd
...
@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
...
@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
offsets
,
rank
,
full_nvlink
)
offsets
,
rank
,
full_nvlink
)
def
should_custom_ar
(
inp
:
torch
.
Tensor
,
max_size
:
int
,
world_size
:
int
,
full_nvlink
:
bool
)
->
bool
:
return
torch
.
ops
.
_C_custom_ar
.
should_custom_ar
(
inp
,
max_size
,
world_size
,
full_nvlink
)
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
99aa4edd
...
@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
...
@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return
True
return
True
def
is_weak_contiguous
(
inp
:
torch
.
Tensor
):
return
inp
.
is_contiguous
()
or
(
inp
.
storage
().
nbytes
()
-
inp
.
storage_offset
()
*
inp
.
element_size
()
==
inp
.
numel
()
*
inp
.
element_size
())
class
CustomAllreduce
:
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
...
@@ -224,8 +230,19 @@ class CustomAllreduce:
...
@@ -224,8 +230,19 @@ class CustomAllreduce:
ops
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
ops
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
def
should_custom_ar
(
self
,
inp
:
torch
.
Tensor
):
def
should_custom_ar
(
self
,
inp
:
torch
.
Tensor
):
return
ops
.
should_custom_ar
(
inp
,
self
.
max_size
,
self
.
world_size
,
if
self
.
disabled
:
self
.
full_nvlink
)
return
False
inp_size
=
inp
.
numel
()
*
inp
.
element_size
()
# custom allreduce requires input byte size to be multiples of 16
if
inp_size
%
16
!=
0
:
return
False
if
not
is_weak_contiguous
(
inp
):
return
False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if
self
.
world_size
==
2
or
self
.
full_nvlink
:
return
inp_size
<
self
.
max_size
return
False
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
# or, in the context of cuda graphs, register_graph_buffers
...
...
vllm/distributed/parallel_state.py
View file @
99aa4edd
...
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
...
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
"""
"""
import
contextlib
import
contextlib
import
pickle
import
pickle
import
weakref
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -69,6 +70,58 @@ def _split_tensor_dict(
...
@@ -69,6 +70,58 @@ def _split_tensor_dict(
return
metadata_list
,
tensor_list
return
metadata_list
,
tensor_list
_group_name_counter
:
Dict
[
str
,
int
]
=
{}
def
_get_unique_name
(
name
:
str
)
->
str
:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if
name
not
in
_group_name_counter
:
_group_name_counter
[
name
]
=
0
newname
=
f
"
{
name
}
:
{
_group_name_counter
[
name
]
}
"
_group_name_counter
[
name
]
+=
1
return
newname
_groups
:
Dict
[
str
,
Callable
[[],
"GroupCoordinator"
]]
=
{}
def
_register_group
(
group
:
"GroupCoordinator"
)
->
None
:
# looks like Python 3.8 does not understand `ReferenceType`
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
# type: ignore
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
@
inplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
return
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
@
outplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
class
GroupCoordinator
:
class
GroupCoordinator
:
"""
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup wrapper for a group of processes.
...
@@ -111,7 +164,11 @@ class GroupCoordinator:
...
@@ -111,7 +164,11 @@ class GroupCoordinator:
use_custom_allreduce
:
bool
,
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
):
group_name
=
group_name
or
"anonymous"
self
.
unique_name
=
_get_unique_name
(
group_name
)
_register_group
(
self
)
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
local_rank
=
local_rank
self
.
local_rank
=
local_rank
...
@@ -149,28 +206,24 @@ class GroupCoordinator:
...
@@ -149,28 +206,24 @@ class GroupCoordinator:
from
vllm.distributed.device_communicators.pynccl
import
(
from
vllm.distributed.device_communicators.pynccl
import
(
PyNcclCommunicator
)
PyNcclCommunicator
)
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
=
None
if
use_pynccl
and
self
.
world_size
>
1
:
if
use_pynccl
and
self
.
world_size
>
1
:
self
.
pynccl_comm
=
PyNcclCommunicator
(
self
.
pynccl_comm
=
PyNcclCommunicator
(
group
=
self
.
cpu_group
,
group
=
self
.
cpu_group
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
else
:
self
.
pynccl_comm
=
None
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
=
None
if
use_custom_allreduce
and
self
.
world_size
>
1
:
if
use_custom_allreduce
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
# Initialize a custom fast all-reduce implementation.
self
.
ca_comm
=
CustomAllreduce
(
self
.
ca_comm
=
CustomAllreduce
(
group
=
self
.
cpu_group
,
group
=
self
.
cpu_group
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
else
:
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
=
None
if
use_tpu_communicator
and
self
.
world_size
>
1
:
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
...
@@ -264,16 +317,46 @@ class GroupCoordinator:
...
@@ -264,16 +317,46 @@ class GroupCoordinator:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
return
self
.
_all_reduce
(
input_
)
if
self
.
ca_comm
is
not
None
and
self
.
ca_comm
.
should_custom_ar
(
input_
):
return
torch
.
ops
.
vllm
.
outplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
else
:
torch
.
ops
.
vllm
.
inplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
return
input_
def
_all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
Always assume this function modifies its input, but use the return
value as the output.
value as the output.
"""
"""
ca_comm
=
self
.
ca_comm
ca_comm
=
self
.
ca_comm
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
# For TPUs, use TPU communicator.
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
...
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
...
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl
=
False
,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
use_tpu_communicator
=
False
,
group_name
=
"world"
,
)
)
...
@@ -767,6 +851,7 @@ def init_model_parallel_group(
...
@@ -767,6 +851,7 @@ def init_model_parallel_group(
backend
:
str
,
backend
:
str
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
)
->
GroupCoordinator
:
)
->
GroupCoordinator
:
if
use_custom_allreduce
is
None
:
if
use_custom_allreduce
is
None
:
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
...
@@ -778,6 +863,7 @@ def init_model_parallel_group(
...
@@ -778,6 +863,7 @@ def init_model_parallel_group(
use_custom_allreduce
=
use_custom_allreduce
,
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
group_name
=
group_name
,
)
)
...
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
...
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
_TP
=
init_model_parallel_group
(
group_ranks
,
_TP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
use_message_queue_broadcaster
=
True
)
use_message_queue_broadcaster
=
True
,
group_name
=
"tp"
)
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
...
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
...
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
_PP
=
init_model_parallel_group
(
group_ranks
,
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
use_custom_allreduce
=
False
)
use_custom_allreduce
=
False
,
group_name
=
"pp"
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment