OpenDAS / ColossalAI · Commits · 293fb40c
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "abf4c27f6adc4b65914744a23ba23c4e60b2a722"
Unverified commit 293fb40c, authored Jan 07, 2022 by ver217, committed by GitHub on Jan 07, 2022

add scatter/gather optim for pipeline (#123)

Parent: 404e6f88

Showing 5 changed files with 166 additions and 56 deletions (+166, -56)
colossalai/communication/__init__.py                +2   -2
colossalai/communication/p2p.py                     +83  -21
colossalai/communication/utils.py                   +28  -0
colossalai/engine/schedule/_pipeline_schedule.py    +51  -32
colossalai/initialize.py                            +2   -1
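In short, this commit adds an optional scatter/gather optimization to pipeline point-to-point communication: when 1D tensor parallelism is active, a tensor is split into equal 1D chunks before being sent to the next or previous stage, and the receiving stage reassembles it with an all-gather over the 1D parallel group, so each p2p message carries only 1/world_size of the data. The sketch below shows how the new flag would be switched on when building a schedule; the scatter_gather_tensors argument matches the constructor change in the diff, but the import path and the surrounding setup are assumptions, not part of this commit.

# Illustrative sketch only: scatter_gather_tensors mirrors the new constructor
# argument below; the import path and microbatch numbers are assumed.
from colossalai.engine.schedule import PipelineSchedule, InterleavedPipelineSchedule

# 1F1B pipeline schedule with the scatter/gather optimization requested.
# The flag only takes effect when ParallelMode.PARALLEL_1D is initialized
# with a world size greater than 1 (see PipelineSchedule.__init__ below).
schedule = PipelineSchedule(num_microbatches=4, scatter_gather_tensors=True)

# The interleaved schedule simply forwards the flag to its parent class.
interleaved = InterleavedPipelineSchedule(num_microbatches=8,
                                          num_model_chunks=2,
                                          scatter_gather_tensors=True)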
colossalai/communication/__init__.py

...
@@ -13,5 +13,5 @@ __all__ = [
     'send_forward_backward_recv_forward_backward', 'send_backward',
     'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',
-    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
-]
+    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
+]
\ No newline at end of file
colossalai/communication/p2p.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+from typing import List, Tuple, Union

 import torch
 import torch.distributed as dist

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from functools import reduce
+import operator
+from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]
+
+
+def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
+    """get the exact tensor shape when communicating and return whether the tensor is a chunk
+
+    :param tensor_shape: shape of tensor
+    :type tensor_shape: TensorShape
+    :param chunk_tensor: whether to chunk tensor, defaults to False
+    :type chunk_tensor: bool, optional
+    :return: exact tensor shape, whether to chunk tensor
+    :rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
+    """
+    if chunk_tensor:
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
+        tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if tensor_chunk_shape % tensor_parallel_world_size == 0:
+            tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
+        else:
+            tensor_chunk_shape = tensor_shape
+            chunk_tensor = False
+    else:
+        tensor_chunk_shape = tensor_shape
+    return tensor_chunk_shape, chunk_tensor
+
+
 def _communicate(tensor_send_next=None,
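To make the helper above concrete, here is a hypothetical standalone rework of _get_tensor_shape with the tensor-parallel world size passed in explicitly instead of read from gpc: a (4, 8, 16) tensor has 512 elements, so with a world size of 4 each rank's chunk is 128 elements and chunking stays enabled, while with a world size of 3 the element count does not divide evenly, so the original shape is returned and chunking is disabled.

import operator
from functools import reduce

def get_tensor_shape_demo(tensor_shape, world_size, chunk_tensor=False):
    # Same arithmetic as _get_tensor_shape above, minus the global context.
    if chunk_tensor:
        numel = reduce(operator.mul, tensor_shape, 1)
        if numel % world_size == 0:
            return numel // world_size, True   # flat per-rank chunk size
        return tensor_shape, False             # not divisible: keep full shape
    return tensor_shape, False

print(get_tensor_shape_demo((4, 8, 16), world_size=4, chunk_tensor=True))  # (128, True)
print(get_tensor_shape_demo((4, 8, 16), world_size=3, chunk_tensor=True))  # ((4, 8, 16), False)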
...
@@ -17,7 +47,8 @@ def _communicate(tensor_send_next=None,
                  recv_next_shape=None,
                  prev_rank=None,
                  next_rank=None,
-                 dtype=None):
+                 dtype=None,
+                 scatter_gather_tensors=False):
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
...
@@ -42,13 +73,15 @@ def _communicate(tensor_send_next=None,
     if recv_prev:
         assert recv_prev_shape is not None
-        tensor_recv_prev = torch.empty(recv_prev_shape,
+        recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
+        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
     if recv_next:
         assert recv_next_shape is not None
-        tensor_recv_next = torch.empty(recv_next_shape,
+        recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
+        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
...
@@ -63,6 +96,16 @@ def _communicate(tensor_send_next=None,
         next_rank = gpc.get_next_global_rank(
             ParallelMode.PIPELINE)

+    if tensor_send_prev is not None:
+        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
+        if send_prev_split:
+            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
+
+    if tensor_send_next is not None:
+        send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
+        if send_next_split:
+            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
+
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
...
@@ -82,10 +125,15 @@ def _communicate(tensor_send_next=None,
             req.wait()
         # To protect against race condition when using batch_isend_irecv().
         torch.cuda.synchronize()
+
+    if recv_prev and recv_prev_split:
+        tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+    if recv_next and recv_next_split:
+        tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the input tensor from the previous member in pipeline.

     :param input_tensor_shape: The shape of the tensor to be recieved
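Taken together, the two blocks above make _communicate send only a local 1D chunk (split before dist.isend) and rebuild the full tensor on the receiver (all-gather after the receive completes), which divides the p2p message volume by the 1D parallel world size. The following single-process sketch is only a stand-in for that distributed round trip; it shows that splitting into equal chunks and concatenating them back recovers the original tensor.

import torch

world_size = 4
full = torch.arange(2 * 6, dtype=torch.float32).reshape(2, 6)  # 12 elements

# Sender side: each rank would transmit one equal 1D chunk,
# i.e. 3 elements here instead of the full 12.
chunks = full.view(-1).chunk(world_size)

# Receiver side: an all-gather collects every rank's chunk; concatenating and
# reshaping recovers the original tensor, as gather_split_1d_tensor + .view() does.
gathered = torch.cat(chunks).view(full.shape)

assert torch.equal(gathered, full)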
...
@@ -101,11 +149,12 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
     input_tensor, _ = _communicate(recv_prev=True,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
-                                   dtype=dtype)
+                                   dtype=dtype,
+                                   scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor


-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the grad tensor from the next member in pipeline.

     :param output_grad_shape: The shape of the tensor to be recieved
...
@@ -121,11 +170,12 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
     _, output_tensor_grad = _communicate(recv_next=True,
                                          recv_next_shape=output_grad_shape,
                                          next_rank=next_rank,
-                                         dtype=dtype)
+                                         dtype=dtype,
+                                         scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad


-def send_forward(output_tensor, next_rank=None):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     """Sends the input tensor to the next member in pipeline.

     :param output_tensor: Tensor to be sent
...
@@ -135,10 +185,11 @@ def send_forward(output_tensor, next_rank=None):
     """
     if not gpc.is_pipeline_last_stage():
         _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank)
+                     next_rank=next_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)


-def send_backward(input_tensor_grad, prev_rank=None):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
     """Sends the grad tensor to the previous member in pipeline.

     :param input_tensor_grad: Tensor to be sent
...
@@ -148,14 +199,16 @@ def send_backward(input_tensor_grad, prev_rank=None):
     """
     if not gpc.is_pipeline_first_stage():
         _communicate(tensor_send_prev=input_tensor_grad,
-                     prev_rank=prev_rank)
+                     prev_rank=prev_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)


 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
                                next_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
...
@@ -174,7 +227,8 @@ def send_forward_recv_backward(output_tensor,
                                               recv_next=recv_next,
                                               recv_next_shape=output_grad_shape,
                                               next_rank=next_rank,
-                                              dtype=dtype)
+                                              dtype=dtype,
+                                              scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
...
@@ -182,7 +236,8 @@ def send_backward_recv_forward(input_tensor_grad,
                               input_tensor_shape,
                               recv_prev=True,
                               prev_rank=None,
-                              dtype=torch.float):
+                              dtype=torch.float,
+                              scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
...
@@ -201,7 +256,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                              recv_prev=recv_prev,
                                              recv_prev_shape=input_tensor_shape,
                                              prev_rank=prev_rank,
-                                             dtype=dtype)
+                                             dtype=dtype,
+                                             scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
...
@@ -210,7 +266,8 @@ def send_forward_recv_forward(output_tensor,
                               recv_prev=True,
                               prev_rank=None,
                               next_rank=None,
-                              dtype=torch.float):
+                              dtype=torch.float,
+                              scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
...
@@ -227,7 +284,8 @@ def send_forward_recv_forward(output_tensor,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
                                    next_rank=next_rank,
-                                   dtype=dtype)
+                                   dtype=dtype,
+                                   scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
...
@@ -236,7 +294,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                recv_next=True,
                                prev_rank=None,
                                next_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
...
@@ -253,7 +312,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                              recv_next_shape=output_grad_shape,
                                              prev_rank=prev_rank,
                                              next_rank=next_rank,
-                                             dtype=dtype)
+                                             dtype=dtype,
+                                             scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
...
@@ -265,7 +325,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 recv_next=True,
                                                 prev_rank=None,
                                                 next_rank=None,
-                                                dtype=torch.float):
+                                                dtype=torch.float,
+                                                scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the next and
     the grad tensor to the previous, while recieves the grad tensor from the
     next and the input tensor from the previous.
...
@@ -290,5 +351,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                       recv_next_shape=output_grad_shape,
                                                       prev_rank=prev_rank,
                                                       next_rank=next_rank,
-                                                      dtype=dtype)
+                                                      dtype=dtype,
+                                                      scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor, output_tensor_grad
colossalai/communication/utils.py

...
@@ -62,3 +62,31 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
         tensor_shape = torch.Size(recv_shape)

     return tensor_shape
+
+
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """Break a tensor into equal 1D chunks."""
+    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """Opposite of above function, gather values from model parallel ranks."""
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    numel = torch.numel(tensor)
+    numel_gathered = world_size * numel
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
+    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+    return gathered
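One detail worth noting in split_tensor_into_1d_equal_chunks: with the default new_buffer=False the returned chunk is a view into the original tensor (no copy), whereas new_buffer=True copies the slice into a fresh buffer that no longer aliases the source. A small CPU-only sketch of that distinction, with a plain slice standing in for the rank-dependent indexing above:

import torch

tensor = torch.zeros(8)
partition_size, start = 4, 0  # stand-ins for the rank-derived values above

view_chunk = tensor.view(-1)[start:start + partition_size]        # new_buffer=False: a view
copy_chunk = torch.empty(partition_size, dtype=tensor.dtype)
copy_chunk.copy_(tensor.view(-1)[start:start + partition_size])   # new_buffer=True: a copy

tensor[0] = 1.0
print(view_chunk[0])  # tensor(1.) - the view tracks the source tensor
print(copy_chunk[0])  # tensor(0.) - the copy does not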
colossalai/engine/schedule/_pipeline_schedule.py

...
@@ -6,7 +6,7 @@ import inspect
 import torch.cuda
 from torch import Tensor

-from colossalai.communication import *
+import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
...
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
     :type num_microbatches: int
     :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
     :type batch_data_process_func: Callable
+    :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+    :type scatter_gather_tensors: bool
     """

     def __init__(self,
                  num_microbatches,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         super().__init__(batch_data_process_func=batch_data_process_func)
         self.num_microbatches = num_microbatches
         self.dtype = torch.float
         self.tensor_shape = tensor_shape
+        self.scatter_gather_tensors = False
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
+            self.scatter_gather_tensors = scatter_gather_tensors

     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
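Note the guard in __init__: the user-supplied flag is only honored when 1D tensor parallelism is actually initialized with more than one rank; otherwise self.scatter_gather_tensors silently stays False. A hypothetical pure function expressing the same decision:

def resolve_scatter_gather(requested: bool, parallel_1d_initialized: bool, parallel_1d_world_size: int) -> bool:
    # Mirrors the guard in PipelineSchedule.__init__ above.
    if parallel_1d_initialized and parallel_1d_world_size > 1:
        return requested
    return False

print(resolve_scatter_gather(True, parallel_1d_initialized=False, parallel_1d_world_size=1))  # False
print(resolve_scatter_gather(True, parallel_1d_initialized=True, parallel_1d_world_size=4))   # True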
...
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
                 return_output_label=return_output_label,
...
@@ -236,8 +243,8 @@ class PipelineSchedule(BaseSchedule):
             )

             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
-                fs_checker = send_tensor_meta(output_tensor, fs_checker)
-            send_forward(output_tensor)
+                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
+            comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)

             if not forward_only:
                 input_tensors.append(input_tensor)
...
@@ -248,8 +255,9 @@ class PipelineSchedule(BaseSchedule):
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)

         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
...
@@ -261,14 +269,15 @@ class PipelineSchedule(BaseSchedule):
                 accum_loss=accum_loss
             )

             if forward_only:
-                send_forward(output_tensor)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)

                 if not last_iteration:
-                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                    input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)

             else:
-                output_tensor_grad = send_forward_recv_backward(
-                    output_tensor, bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.send_forward_recv_backward(
+                    output_tensor, bt_shape, dtype=self.dtype,
+                    scatter_gather_tensors=self.scatter_gather_tensors)

                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
...
@@ -287,10 +296,10 @@ class PipelineSchedule(BaseSchedule):
                 if last_iteration:
                     input_tensor = None
-                    send_backward(input_tensor_grad)
+                    comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
-                    input_tensor = send_backward_recv_forward(
-                        input_tensor_grad, ft_shape, dtype=self.dtype)
+                    input_tensor = comm.send_backward_recv_forward(
+                        input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)

         # Run cooldown backward passes.
         if not forward_only:
...
@@ -298,7 +307,8 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)

-                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
+                                                        scatter_gather_tensors=self.scatter_gather_tensors)

                 input_tensor_grad = self.backward_step(
                     engine,
...
@@ -306,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
                     output_tensor_grad
                 )

-                send_backward(input_tensor_grad)
+                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)

         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
...
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                  num_microbatches,
                  num_model_chunks,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
         It uses interleaved 1F1B strategy. Other properties are similar as
         :class:`NonPipelineSchedule`.
...
@@ -333,10 +344,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         :type num_model_chunks: int
         :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
         :type batch_data_process_func: Callable
+        :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+        :type scatter_gather_tensors: bool
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
+        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
+                         tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
         self.num_model_chunks = num_model_chunks
...
@@ -494,15 +508,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         # Run warmup forward passes.
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
-            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
-        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+            input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
+        input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
+                                                  scatter_gather_tensors=self.scatter_gather_tensors))

         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
             output_tensor = forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
-                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
                     output_tensor, send_tensor_shape_flags[model_chunk_id])
             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
...
@@ -519,7 +534,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                 if not gpc.is_pipeline_first_stage():
-                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                    input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
                         input_tensor_shapes[next_forward_model_chunk_id])

             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
...
@@ -532,20 +547,22 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
                 input_tensor, output_tensor_grad = \
-                    send_forward_backward_recv_forward_backward(
+                    comm.send_forward_backward_recv_forward_backward(
                         output_tensor, input_tensor_grad,
                         input_shape,
                         output_shape,
                         recv_prev=recv_prev, recv_next=recv_next,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
             else:
                 input_tensor = \
-                    send_forward_recv_forward(
+                    comm.send_forward_recv_forward(
                         output_tensor,
                         input_shape,
                         recv_prev=recv_prev,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 input_tensors[next_forward_model_chunk_id].append(input_tensor)

         # Run 1F1B in steady state.
...
@@ -608,12 +625,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None

             # Communicate tensors.
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                comm.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     input_shape,
                     output_shape,
                     recv_prev=recv_prev, recv_next=recv_next,
-                    dtype=self.dtype)
+                    dtype=self.dtype,
+                    scatter_gather_tensors=self.scatter_gather_tensors)

             # Put input_tensor and output_tensor_grad in data structures in the
             # right location.
...
@@ -627,7 +645,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if not forward_only:
             if all_warmup_microbatches:
                 output_tensor_grads[num_model_chunks - 1].append(
-                    recv_backward(output_tensor_shapes[num_model_chunks - 1]))
+                    comm.recv_backward(output_tensor_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
                 input_tensor_grad = backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
...
@@ -639,11 +657,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                 output_tensor_grads[next_backward_model_chunk_id].append(
-                    send_backward_recv_backward(
+                    comm.send_backward_recv_backward(
                         input_tensor_grad,
                         output_shape,
                         recv_next=recv_next,
-                        dtype=self.dtype))
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors))

         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
...
colossalai/initialize.py

...
@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # initialize amp
     amp_mode = None
     if fp16_cfg is not None and fp16_cfg.mode is not None:
-        # TODO: pipeline only support NAIVE AMP
         cfg_ = fp16_cfg.copy()
         amp_mode = cfg_.pop('mode')
+        if is_using_pp():
+            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
         if amp_mode == AMP_TYPE.NAIVE:
             cfg_['clip_grad'] = clip_grad_norm
         model, optimizer, criterion = convert_to_amp(model=model,
...
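The hunk above replaces the old TODO comment with an explicit check: when pipeline parallelism is in use, the AMP mode must be AMP_TYPE.NAIVE or the assertion fires. A minimal config sketch that satisfies it; the fp16 = dict(...) / parallel = dict(...) style follows ColossalAI's usual configuration convention and is not part of this diff:

# config.py (illustrative): pipeline parallelism combined with naive AMP,
# the only AMP mode the pipeline path accepts after this commit.
from colossalai.amp import AMP_TYPE

parallel = dict(pipeline=2, tensor=dict(size=2, mode='1d'))
fp16 = dict(mode=AMP_TYPE.NAIVE)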