OpenDAS / ColossalAI · Commit 8f02a88d (unverified)

Authored Dec 20, 2021 by ver217; committed by GitHub on Dec 20, 2021
Parent: 91c327cb

    add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)

Showing 17 changed files with 554 additions and 180 deletions.
Changed files:

- colossalai/amp/naive_amp/__init__.py (+9, -3)
- colossalai/amp/naive_amp/_fp16_optimizer.py (+36, -20)
- colossalai/builder/__init__.py (+2, -2)
- colossalai/builder/pipeline.py (+94, -83)
- colossalai/communication/p2p.py (+32, -23)
- colossalai/communication/utils.py (+2, -8)
- colossalai/context/parallel_context.py (+20, -0)
- colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py (+1, -1)
- colossalai/engine/schedule/__init__.py (+2, -2)
- colossalai/engine/schedule/_pipeline_schedule.py (+316, -14)
- colossalai/initialize.py (+10, -10)
- colossalai/utils/__init__.py (+3, -2)
- colossalai/utils/common.py (+10, -0)
- docs/parallelization.md (+10, -4)
- tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py (+3, -4)
- tests/test_trainer/test_pipeline/test_partition.py (+2, -2)
- tests/test_trainer/test_pipeline/test_pipeline_schedule.py (+2, -2)
colossalai/amp/naive_amp/__init__.py

```diff
@@ -20,10 +20,16 @@ def convert_to_naive_amp(model: nn.Module,
     :return: (model, optimizer)
     :rtype: Tuple
     """
-    if is_no_pp_or_last_stage():
-        model = NaiveAMPModel(model, output_to_fp32=True)
+    if isinstance(model, nn.ModuleList):
+        # interleaved pipeline
+        module_list = []
+        for chunk, m in enumerate(model):
+            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
+            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
+        model = nn.ModuleList(module_list)
     else:
-        model = NaiveAMPModel(model, output_to_fp32=False)
+        output_to_fp32 = is_no_pp_or_last_stage()
+        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
 
     optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
     return model, optimizer
```
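Note: the wrapping decision above is easy to check in isolation. Below is a minimal standalone sketch (plain Python with a hypothetical `wrap_flags` helper, not part of ColossalAI) of which chunks end up with `output_to_fp32=True` under the new branch.

```python
import torch.nn as nn

def wrap_flags(model, is_last_stage):
    # mirrors the new convert_to_naive_amp logic: with an interleaved
    # pipeline only the final chunk of the last stage casts its output
    # back to fp32; every other chunk keeps fp16 outputs
    if isinstance(model, nn.ModuleList):
        return [is_last_stage and chunk == len(model) - 1
                for chunk in range(len(model))]
    return [is_last_stage]

chunks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
print(wrap_flags(chunks, is_last_stage=True))   # [False, True]
print(wrap_flags(chunks, is_last_stage=False))  # [False, False]
```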
colossalai/amp/naive_amp/_fp16_optimizer.py

```diff
@@ -14,7 +14,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier, is_using_pp)
 
 
 def _zero_grad_group_helper(group, set_to_none):
@@ -58,7 +58,8 @@ class DynamicGradScaler:
                  backoff_factor,
                  growth_interval,
                  hysteresis,
-                 max_scale: int = None):
+                 max_scale: int = None,
+                 verbose: bool = False):
         """"Grad scaler with dynamic scale that gets adjusted
         during training."""
         assert initial_scale > 0.0
@@ -91,6 +92,7 @@ class DynamicGradScaler:
         self._hysteresis_tracker = self.hysteresis
 
         self._logger = get_dist_logger()
+        self.verbose = verbose
 
     @property
     def scale(self):
@@ -111,7 +113,8 @@ class DynamicGradScaler:
             if self._hysteresis_tracker <= 0:
                 self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
-                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
+                if self.verbose:
+                    self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
             # If there is no nan/inf, increment the growth tracker.
             self._growth_tracker += 1
@@ -122,11 +125,14 @@ class DynamicGradScaler:
                 self._hysteresis_tracker = self.hysteresis
                 # and scale up the loss scale.
                 if self._max_scale is not None and self._scale >= self._max_scale:
-                    self._logger.info(
-                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
-                        ranks=[0])
+                    if self.verbose:
+                        self._logger.info(
+                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
+                            ranks=[0])
                 else:
                     self._scale = self._scale * self.growth_factor
-                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
+                    if self.verbose:
+                        self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
 
     def state_dict(self):
         state_dict = {}
@@ -162,6 +168,8 @@ class FP16Optimizer(Optimizer):
     :type hysterisis: int
     :param max_scale: maximum loss scale allowed
     :type max_scale: int
+    :param verbose: if set to `True`, will print debug info
+    :type verbose: bool
     """
 
     def __init__(self,
@@ -174,27 +182,29 @@ class FP16Optimizer(Optimizer):
                  backoff_factor=0.5,
                  growth_interval=1000,
                  hysteresis=2,
-                 max_scale: int = 2 ** 32):
+                 max_scale: int = 2 ** 32,
+                 verbose: bool = False):
         # default args for compatibility
         bf16 = False
-        params_have_main_grad = True
+        params_have_main_grad = False
 
         # have a defaults for compatibility with pytorch optim
         self.defaults = optimizer.defaults
 
         # log config
         self._logger = get_dist_logger()
-        self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
-                          f"Optimizer: {optimizer.__class__.__name__}\n"
-                          f"clip_grad = {clip_grad}\n"
-                          f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
-                          f"initial_scale = {initial_scale}\n"
-                          f"min_scale = {min_scale}\n"
-                          f"growth_factor = {growth_factor}\n"
-                          f"backoff_factor = {backoff_factor}\n"
-                          f"growth_interval = {growth_interval}\n"
-                          f"hysteresis = {hysteresis}\n"
-                          f"==========================================", ranks=[0])
+        if verbose:
+            self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
+                              f"Optimizer: {optimizer.__class__.__name__}\n"
+                              f"clip_grad = {clip_grad}\n"
+                              f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
+                              f"initial_scale = {initial_scale}\n"
+                              f"min_scale = {min_scale}\n"
+                              f"growth_factor = {growth_factor}\n"
+                              f"backoff_factor = {backoff_factor}\n"
+                              f"growth_interval = {growth_interval}\n"
+                              f"hysteresis = {hysteresis}\n"
+                              f"==========================================", ranks=[0])
 
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
@@ -212,7 +222,8 @@ class FP16Optimizer(Optimizer):
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
             hysteresis=hysteresis,
-            max_scale=max_scale
+            max_scale=max_scale,
+            verbose=verbose
         )
 
         # None grad scaler is only supported for bf16.
@@ -350,6 +361,11 @@ class FP16Optimizer(Optimizer):
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=gpc.get_group(ParallelMode.TENSOR))
 
+        if is_using_pp():
+            torch.distributed.all_reduce(self.found_inf,
+                                         op=torch.distributed.ReduceOp.MAX,
+                                         group=gpc.get_group(ParallelMode.PIPELINE))
+
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
         return found_inf_flag
```
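Note: the `verbose` flag only silences logging; the scaling policy itself is unchanged. For reference, here is a simplified sketch of that policy (a toy class, not the library's `DynamicGradScaler`; it omits the hysteresis tracker and the distributed logger).

```python
class TinyGradScaler:
    """Simplified dynamic grad scaler: back off on overflow, grow after
    a run of overflow-free steps (sketch of the same policy)."""
    def __init__(self, initial_scale=2**16, min_scale=1.0, max_scale=2**32,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=1000):
        self.scale = initial_scale
        self.min_scale, self.max_scale = min_scale, max_scale
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0

    def update(self, found_inf: bool):
        if found_inf:
            # overflow: reset the growth streak and shrink the scale
            self._growth_tracker = 0
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
        else:
            # no overflow: grow after growth_interval clean steps
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self._growth_tracker = 0
                if self.scale < self.max_scale:
                    self.scale *= self.growth_factor

scaler = TinyGradScaler(initial_scale=8, growth_interval=2)
for step_overflowed in [False, False, True, False, False]:
    scaler.update(step_overflowed)
print(scaler.scale)  # 8 -> grows to 16 -> backs off to 8 -> grows to 16
```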
colossalai/builder/__init__.py

```diff
 from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer,
                       build_layer, build_loss, build_hooks, build_dataset, build_transform,
                       build_data_sampler, build_gradient_handler)
-from .pipeline import PipelineModelInitializer
+from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
 
 __all__ = [
     'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_layer',
     'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
-    'build_gradient_handler', 'PipelineModelInitializer'
+    'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
 ]
```
colossalai/builder/pipeline.py

```diff
 import copy
 import heapq
 
 from colossalai.builder import build_model, build_layer
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import set_to_cuda
+import torch.nn as nn
 
 
 def _binary_partition(weights, st, ed):
@@ -150,7 +151,19 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts
 
 
-class PipelineModelInitializer():
+def _count_layer_params(layers):
+    """Count the number of parameters in each layer
+    """
+    param_counts = [0] * len(layers)
+    for idx, cfg in enumerate(layers):
+        layer = build_layer(cfg)
+        params = filter(lambda p: p.requires_grad, layer.parameters())
+        param_counts[idx] = sum(p.numel() for p in params)
+    return param_counts
+
+
+def build_pipeline_model_from_cfg(config, num_chunks: int = 1,
+                                  partition_method: str = 'parameter', verbose: bool = False):
     """An intializer to split the model into different stages for pipeline parallelism.
 
     An example for the model config is shown below. The class VisionTransformerFromConfig should
@@ -168,88 +181,86 @@ def build_pipeline_model_from_cfg(config, num_chunks, partition_method, verbose):
     :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
                         in most cases unless you are using virutal pipeline parallelism.
     :type num_chunks: int
+    :param partition_method: this parameter determines how you want to split your model layers into stages,
+                       you can set it as 'layer' or 'parameter'
+    :type partition_method: str
+    :param verbose: whether to print the logs
+    :type verbose: bool
     """
-    def __init__(self, config, num_chunks, verbose=False):
-        self.num_chunks = num_chunks
-        self.ori_model = build_model(config)
-        self.layers = self.ori_model.layers_cfg
-        layer_length = len(self.layers)
-        self.verbose = verbose
-        self._logger = get_dist_logger()
-        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
-
-    def initialize(self, partition_method='parameter'):
-        """Initialize the model object from the config passed
-
-        :param partition_method: this parameter determines how you want to split your model layers into stages,
-                       you can set it as 'layer' or 'parameter'
-        :type partition_method: str
-        """
-        # Some space for initializing comunication groups
-        self._interval = None
-        self._partition_layers(method=partition_method)
-        models = self._build()
-        model = set_to_cuda(models)
-
-        return model
-
-    def _partition_layers(self, method):
-        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
-        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-
-        method = method.lower()
-        # Make a partition
-        if method == 'layer':
-            num_layers = len(self.layers)
-            self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
-        elif method == 'parameter':
-            param_counts = self._count_layer_params()
-            # print_rank_0(param_counts)
-            self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
-        else:
-            raise ValueError("Method should be a pre-set string in [layer, parameter]")
-
-        # Display the partition
-        if gpc.get_global_rank() == 0 and self.verbose:
-            log_str = 'Layer allocation after partitioning: \n'
-            for stage in range(pipeline_parallel_size):
-                num_layers = 0
-                for st, ed in self.parts[stage]:
-                    num_layers += ed - st
-                log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
-                for st, ed in self.parts[stage]:
-                    for idx, layer in enumerate(self.layers[st: ed]):
-                        log_str += f'\t{idx + st:2d}: {layer}\n'
-            self._logger.info(log_str, ranks=[0])
-
-        # Save the partition
-        self._interval = self.parts[pipeline_rank]
-
-    def _build(self):
-        """Build model from the layer cfg according to the partition
-        """
-        models = []
-        for st, ed in self._interval:
-            model = copy.copy(self.ori_model)
-            model.build_from_cfg(st, ed)
-            models.append(model)
-
-        return models
-
-    def _count_layer_params(self):
-        """Count the number of parameters in each layer
-        """
-        param_counts = [0] * len(self.layers)
-        for idx, cfg in enumerate(self.layers):
-            layer = build_layer(cfg)
-            params = filter(lambda p: p.requires_grad, layer.parameters())
-            param_counts[idx] = sum(p.numel() for p in params)
-
-        return param_counts
+    ori_model = build_model(config)
+    layers = ori_model.layers_cfg
+    layer_length = len(layers)
+    logger = get_dist_logger()
+    if verbose:
+        logger.info(f"The total length of layers is {layer_length}", ranks=[0])
+
+    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    method = partition_method.lower()
+    # Make a partition
+    if method == 'layer':
+        num_layers = len(layers)
+        parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
+    elif method == 'parameter':
+        param_counts = _count_layer_params(layers)
+        # print_rank_0(param_counts)
+        parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
+    else:
+        raise ValueError("Method should be a pre-set string in [layer, parameter]")
+
+    # Display the partition
+    if verbose:
+        log_str = 'Layer allocation after partitioning: \n'
+        for stage in range(pipeline_parallel_size):
+            num_layers = 0
+            for st, ed in parts[stage]:
+                num_layers += ed - st
+            log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
+            for st, ed in parts[stage]:
+                for idx, layer in enumerate(layers[st: ed]):
+                    log_str += f'\t{idx + st:2d}: {layer}\n'
+        logger.info(log_str, ranks=[0])
+
+    # Save the partition
+    interval = parts[pipeline_rank]
+
+    models = []
+    for st, ed in interval:
+        model = copy.deepcopy(ori_model)
+        model.build_from_cfg(st, ed)
+        models.append(model)
+
+    return nn.ModuleList(models) if len(models) > 1 else models[0]
+
+
+def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
+    """An intializer to split the model into different stages for pipeline parallelism.
+    Note that `layer` must be `torch.nn.Sequential`.
+
+    :param layers: layers of model
+    :type config: `torch.nn.Sequential`
     :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
                         in most cases unless you are using virutal pipeline parallelism.
     :type num_chunks: int
+    :param verbose: whether to print the logs
+    :type verbose: bool
     """
+    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
+    module_list = []
+    for start, end in partitions[pipeline_rank]:
+        module_list.append(nn.Sequential(*layers[start:end]))
+    if verbose:
+        logger = get_dist_logger()
+        logger.info(f'Total {len(layers)} layers', ranks=[0])
+        for rank, part in enumerate(partitions):
+            log_str = f'===== stage={rank} =====\n'
+            for chunk, (start, end) in enumerate(part):
+                log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
+                log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
+            logger.info(log_str, ranks=[0])
+    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
```
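Note: a short usage sketch of the new functional API (assumes `colossalai.launch` has already set up the pipeline process groups; untested illustration):

```python
import torch.nn as nn
from colossalai.builder import build_pipeline_model

# a toy 8-layer model expressed as nn.Sequential
layers = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])

# each pipeline rank keeps only its slice of the layers; with
# num_chunks > 1 the result is an nn.ModuleList of chunks, which is
# what InterleavedPipelineSchedule expects
model = build_pipeline_model(layers, num_chunks=2, verbose=True)
```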
colossalai/communication/p2p.py

```diff
@@ -63,9 +63,6 @@ def _communicate(tensor_send_next=None,
         next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
 
-    # rank = dist.get_rank()
-    rank = gpc.get_global_rank()
-
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -88,7 +85,7 @@ def _communicate(tensor_send_next=None,
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(input_tensor_shape, prev_rank=None):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
     """Receives the input tensor from the previous member in pipeline.
 
     :param input_tensor_shape: The shape of the tensor to be recieved
@@ -98,16 +95,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
     :return: The input tensor in forward step
     :rtype: :class:`torch.Tensor`
     """
-    if gpc.is_first_rank(ParallelMode.PIPELINE):
+    if gpc.is_pipeline_first_stage():
         input_tensor = None
     else:
         input_tensor, _ = _communicate(recv_prev=True,
                                        recv_prev_shape=input_tensor_shape,
-                                       prev_rank=prev_rank)
+                                       prev_rank=prev_rank,
+                                       dtype=dtype)
     return input_tensor
 
 
-def recv_backward(output_grad_shape, next_rank=None):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
     """Receives the grad tensor from the next member in pipeline.
 
     :param output_grad_shape: The shape of the tensor to be recieved
@@ -117,12 +115,13 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
     :return: The grad of output tensor in forward step
     :rtype: :class:`torch.Tensor`
     """
-    if gpc.is_last_rank(ParallelMode.PIPELINE):
+    if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
         _, output_tensor_grad = _communicate(recv_next=True,
                                              recv_next_shape=output_grad_shape,
-                                             next_rank=next_rank)
+                                             next_rank=next_rank,
+                                             dtype=dtype)
     return output_tensor_grad
@@ -134,7 +133,7 @@ def send_forward(output_tensor, next_rank=None):
     :type output_tensor: :class:`torch.Tensor`
     :type next_rank: int, optional
     """
-    if not gpc.is_last_rank(ParallelMode.PIPELINE):
+    if not gpc.is_pipeline_last_stage():
         _communicate(tensor_send_next=output_tensor,
                      next_rank=next_rank)
@@ -147,7 +146,7 @@ def send_backward(input_tensor_grad, prev_rank=None):
     :type input_tensor_grad: :class:`torch.Tensor`
     :type prev_rank: int, optional
     """
-    if not gpc.is_first_rank(ParallelMode.PIPELINE):
+    if not gpc.is_pipeline_first_stage():
         _communicate(tensor_send_prev=input_tensor_grad,
                      prev_rank=prev_rank)
@@ -155,7 +154,8 @@ def send_backward(input_tensor_grad, prev_rank=None):
 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
-                               next_rank=None):
+                               next_rank=None,
+                               dtype=torch.float):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
@@ -167,20 +167,22 @@ def send_forward_recv_backward(output_tensor,
     :return: The grad of output tensor in forward step
     :rtype: :class:`torch.Tensor`
     """
-    if gpc.is_last_rank(ParallelMode.PIPELINE):
+    if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
         _, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
-                                             next_rank=next_rank)
+                                             next_rank=next_rank,
+                                             dtype=dtype)
     return output_tensor_grad
 
 
 def send_backward_recv_forward(input_tensor_grad,
                                input_tensor_shape,
                                recv_prev=True,
-                               prev_rank=None):
+                               prev_rank=None,
+                               dtype=torch.float):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
@@ -192,13 +194,14 @@ def send_backward_recv_forward(input_tensor_grad,
     :return: The input tensor in forward step
     :rtype: :class:`torch.Tensor`
     """
-    if gpc.is_first_rank(ParallelMode.PIPELINE):
+    if gpc.is_pipeline_first_stage():
         input_tensor = None
     else:
         input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
                                        recv_prev=recv_prev,
                                        recv_prev_shape=input_tensor_shape,
-                                       prev_rank=prev_rank)
+                                       prev_rank=prev_rank,
+                                       dtype=dtype)
     return input_tensor
@@ -206,7 +209,8 @@ def send_forward_recv_forward(output_tensor,
                               input_tensor_shape,
                               recv_prev=True,
                               prev_rank=None,
-                              next_rank=None):
+                              next_rank=None,
+                              dtype=torch.float):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
@@ -222,7 +226,8 @@ def send_forward_recv_forward(output_tensor,
                                    recv_prev=recv_prev,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
-                                   next_rank=next_rank)
+                                   next_rank=next_rank,
+                                   dtype=dtype)
     return input_tensor
@@ -230,7 +235,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                 output_grad_shape,
                                 recv_next=True,
                                 prev_rank=None,
-                                next_rank=None):
+                                next_rank=None,
+                                dtype=torch.float):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
@@ -246,7 +252,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
                                              prev_rank=prev_rank,
-                                             next_rank=next_rank)
+                                             next_rank=next_rank,
+                                             dtype=dtype)
     return output_tensor_grad
@@ -257,7 +264,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 recv_prev=True,
                                                 recv_next=True,
                                                 prev_rank=None,
-                                                next_rank=None):
+                                                next_rank=None,
+                                                dtype=torch.float):
     """Batched communication operation. Sends the input tensor to the next and
     the grad tensor to the previous, while recieves the grad tensor from the
     next and the input tensor from the previous.
@@ -281,5 +289,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
         recv_prev_shape=input_tensor_shape,
         recv_next_shape=output_grad_shape,
         prev_rank=prev_rank,
-        next_rank=next_rank)
+        next_rank=next_rank,
+        dtype=dtype)
     return input_tensor, output_tensor_grad
```
colossalai/communication/utils.py

```diff
@@ -29,14 +29,8 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
         send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
         send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
-        ops = [
-            dist.P2POp(dist.isend, send_ndims, next_rank),
-            dist.P2POp(dist.isend, send_shape, next_rank)
-        ]
-        reqs = dist.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
-        torch.cuda.synchronize()
+        dist.send(send_ndims, next_rank)
+        dist.send(send_shape, next_rank)
 
         return False
```
colossalai/context/parallel_context.py

```diff
@@ -53,6 +53,8 @@ class ParallelContext:
         self.data_parallel_size = 1
         self.pipeline_parallel_size = 1
         self.tensor_parallel_size = 1
+        self.virtual_pipeline_parallel_size = None
+        self.virtual_pipeline_parallel_rank = None
 
         # logging
         self._verbose = False
@@ -205,6 +207,18 @@ class ParallelContext:
         world_size = self.get_world_size(parallel_mode)
         return rank == world_size - 1
 
+    def is_pipeline_first_stage(self, ignore_virtual=False):
+        if not ignore_virtual:
+            if self.virtual_pipeline_parallel_size is not None \
+                    and self.virtual_pipeline_parallel_rank != 0:
+                return False
+        return self.is_first_rank(ParallelMode.PIPELINE)
+
+    def is_pipeline_last_stage(self, ignore_virtual=False):
+        if not ignore_virtual:
+            if self.virtual_pipeline_parallel_size is not None \
+                    and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+                return False
+        return self.is_last_rank(ParallelMode.PIPELINE)
+
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.
@@ -494,3 +508,9 @@ class ParallelContext:
             self._logger.info('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', ranks=[0])
+
+    def set_virtual_pipeline_parallel_size(self, size):
+        self.virtual_pipeline_parallel_size = size
+
+    def set_virtual_pipeline_parallel_rank(self, rank):
+        self.virtual_pipeline_parallel_rank = rank
```
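Note: the virtual-stage predicates can be read in isolation. Here is a self-contained sketch of the same logic (a plain-Python stand-in for the context object, not the library class):

```python
class Ctx:
    """Stand-in for ParallelContext holding only the fields the new
    predicates consult."""
    def __init__(self, rank, world_size, vpp_size=None, vpp_rank=None):
        self.rank, self.world_size = rank, world_size
        self.virtual_pipeline_parallel_size = vpp_size
        self.virtual_pipeline_parallel_rank = vpp_rank

    def is_pipeline_first_stage(self, ignore_virtual=False):
        # a stage is "virtually first" only while executing model chunk 0
        if not ignore_virtual:
            if self.virtual_pipeline_parallel_size is not None \
                    and self.virtual_pipeline_parallel_rank != 0:
                return False
        return self.rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        # symmetric check for the final model chunk
        if not ignore_virtual:
            if self.virtual_pipeline_parallel_size is not None \
                    and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                return False
        return self.rank == self.world_size - 1

# rank 0 of 4 stages, currently executing its second model chunk:
ctx = Ctx(rank=0, world_size=4, vpp_size=2, vpp_rank=1)
print(ctx.is_pipeline_first_stage())                     # False: not chunk 0
print(ctx.is_pipeline_first_stage(ignore_virtual=True))  # True: physically first
```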
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py

```diff
@@ -32,7 +32,7 @@ class DataParallelGradientHandler(BaseGradientHandler):
                 if tp not in buckets:
                     buckets[tp] = []
                 buckets[tp].append(param)
-                param.main_grad = param.grad
+                # param.main_grad = param.grad
 
         # For each bucket, all-reduce and copy all-reduced grads.
         for tp in buckets:
```
colossalai/engine/schedule/__init__.py

```diff
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule
+from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
 from ._non_pipeline_schedule import NonPipelineSchedule
 
-__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']
+__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule', 'InterleavedPipelineSchedule']
```
colossalai/engine/schedule/_pipeline_schedule.py

```diff
@@ -13,9 +13,8 @@ from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
 from ._base_schedule import BaseSchedule
-from colossalai.amp import AMP_TYPE
 
 
 def squeeze(x: Union[Tensor, tuple, list]):
@@ -47,6 +46,7 @@ class PipelineSchedule(BaseSchedule):
         self.num_microbatches = num_microbatches
         self.sync_data = sync_data
+        self.dtype = torch.float
 
     def _move_to_device(self, data):
         if isinstance(data, (
@@ -122,12 +122,8 @@ class PipelineSchedule(BaseSchedule):
                 "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
             )
 
-        # LSG: set default dtype to fp16 for communication
         if isinstance(engine.model, NaiveAMPModel):
-            torch.set_default_dtype(torch.half)
-            self.logger.warning(
-                'default tensor dtype is set to torch.half for fp16 training',
-                ranks=[0])
+            self.dtype = torch.half
 
     def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
         """Forward step for passed-in model. If it is the first stage, the input tensor
@@ -140,7 +136,7 @@ class PipelineSchedule(BaseSchedule):
     :type input_tensor: :class:`torch.Tensor`
     :param return_tensors: a list of tensors to return
     :type return_tensors: List[:class:`torch.Tensor`]
-
+    :return: output or the loss value of the current pipeline stage
     :rtype: :class:`torch.Tensor`
     """
@@ -252,7 +248,7 @@ class PipelineSchedule(BaseSchedule):
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape)
+            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
                 return_loss=return_loss
@@ -272,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape)
+            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
 
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -286,11 +282,11 @@ class PipelineSchedule(BaseSchedule):
                 send_forward(output_tensor)
 
                 if not last_iteration:
-                    input_tensor = recv_forward(ft_shape)
+                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)
 
             else:
                 output_tensor_grad = send_forward_recv_backward(
-                    output_tensor, bt_shape)
+                    output_tensor, bt_shape, dtype=self.dtype)
 
                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
@@ -312,7 +308,7 @@ class PipelineSchedule(BaseSchedule):
                     send_backward(input_tensor_grad)
                 else:
                     input_tensor = send_backward_recv_forward(
-                        input_tensor_grad, ft_shape)
+                        input_tensor_grad, ft_shape, dtype=self.dtype)
 
         # Run cooldown backward passes.
         if not forward_only:
@@ -320,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                output_tensor_grad = recv_backward(bt_shape)
+                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
 
                 input_tensor_grad = self.backward_step(
                     engine,
@@ -340,3 +336,309 @@ class PipelineSchedule(BaseSchedule):
             return tuple((torch.cat(return_tensors, dim=0), None, None))
         else:
             return tuple((None, None, None))
+
+
+class InterleavedPipelineSchedule(PipelineSchedule):
+    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
+        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
+            'num_microbatches must be an integer multiple of pipeline parallel world size'
+        super().__init__(num_microbatches, sync_data=sync_data)
+        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
+        gpc.set_virtual_pipeline_parallel_rank(0)
+
+    def pre_processing(self, engine):
+        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
+            raise TypeError(
+                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
+            )
+        if isinstance(engine.model[0], NaiveAMPModel):
+            self.dtype = torch.half
+
+    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
+        """Forward step for passed-in model. If it is the first stage, the input tensor
+        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
+        Returns output tensor. This is a helper function and can be ignored by users.
+        """
+        if input_tensor is None:
+            input_tensor, label = self.load_micro_batch()
+        input_tensor = squeeze(input_tensor)
+        output_tensor = model(input_tensor)
+        output_tensor = squeeze(output_tensor)
+
+        if gpc.is_pipeline_last_stage():
+            if return_loss:
+                input_tensor, label = self.load_micro_batch()
+                loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
+                return_tensors.append(tuple((output_tensor, label[0], loss_reduced)))
+                return loss_reduced
+            else:
+                return_tensors.append(output_tensor)
+                return output_tensor
+        else:
+            return output_tensor
+
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
+        """Run interleaved 1F1B schedule (model split into model chunks), with
+        communication between pipeline stages as needed.
+
+        Returns dictionary with losses if the last stage, empty dict otherwise."""
+        assert forward_only or return_loss, \
+            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
+        self.load_batch(data_iter)
+        model = engine.model
+        input_tensors = [[] for _ in range(len(model))]
+        output_tensors = [[] for _ in range(len(model))]
+        return_tensors = []
+        if not forward_only:
+            output_tensor_grads = [[] for _ in range(len(model))]
+
+        # Used for tensor meta information communication
+        input_tensor_shapes = [None for _ in range(len(model))]
+        output_tensor_shapes = [None for _ in range(len(model))]
+        send_tensor_shape_flags = [True for _ in range(len(model))]
+
+        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+        # Compute number of warmup and remaining microbatches.
+        num_model_chunks = len(model)
+        num_microbatches = self.num_microbatches * num_model_chunks
+        all_warmup_microbatches = False
+        if forward_only:
+            num_warmup_microbatches = num_microbatches
+        else:
+            # Run all forward passes and then all backward passes if number of
+            # microbatches is just the number of pipeline stages.
+            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+            # all workers, followed by more microbatches after depending on
+            # stage ID (more forward passes for earlier stages, later stages can
+            # immediately start with 1F1B).
+            if self.num_microbatches == pipeline_parallel_size:
+                num_warmup_microbatches = num_microbatches
+                all_warmup_microbatches = True
+            else:
+                num_warmup_microbatches = \
+                    (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+                num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
+                num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+        num_microbatches_remaining = \
+            num_microbatches - num_warmup_microbatches
+
+        def get_model_chunk_id(microbatch_id, forward):
+            """Helper method to get the model chunk ID given the iteration number."""
+            microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+            model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+            if not forward:
+                model_chunk_id = (num_model_chunks - model_chunk_id - 1)
+            return model_chunk_id
+
+        def forward_step_helper(microbatch_id):
+            """Helper method to run forward step with model split into chunks
+            (run set_virtual_pipeline_model_parallel_rank() before calling
+            forward_step())."""
+            model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
+
+            # forward step
+            if gpc.is_pipeline_first_stage():
+                if len(input_tensors[model_chunk_id]) == \
+                        len(output_tensors[model_chunk_id]):
+                    input_tensors[model_chunk_id].append(None)
+            input_tensor = input_tensors[model_chunk_id][-1]
+            output_tensor = self.forward_step(
+                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
+            output_tensors[model_chunk_id].append(output_tensor)
+
+            # if forward-only, no need to save tensors for a backward pass
+            if forward_only:
+                input_tensors[model_chunk_id].pop()
+                output_tensors[model_chunk_id].pop()
+
+            return output_tensor
+
+        def backward_step_helper(microbatch_id):
+            """Helper method to run backward step with model split into chunks
+            (run set_virtual_pipeline_model_parallel_rank() before calling
+            backward_step())."""
+            model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
+
+            if gpc.is_pipeline_last_stage():
+                if len(output_tensor_grads[model_chunk_id]) == 0:
+                    output_tensor_grads[model_chunk_id].append(None)
+            input_tensor = input_tensors[model_chunk_id].pop(0)
+            output_tensor = output_tensors[model_chunk_id].pop(0)
+            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+            input_tensor_grad = self.backward_step(
+                engine, input_tensor, output_tensor, output_tensor_grad)
+
+            return input_tensor_grad
+
+        # Run warmup forward passes.
+        gpc.set_virtual_pipeline_parallel_rank(0)
+        if not gpc.is_pipeline_first_stage():
+            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
+        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+
+        for k in range(num_warmup_microbatches):
+            model_chunk_id = get_model_chunk_id(k, forward=True)
+            output_tensor = forward_step_helper(k)
+            if not gpc.is_pipeline_last_stage():
+                output_tensor_shapes[model_chunk_id] = output_tensor.shape
+                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                    output_tensor, send_tensor_shape_flags[model_chunk_id])
+
+            # Determine if tensor should be received from previous stage.
+            next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
+            recv_prev = True
+            if gpc.is_pipeline_first_stage(ignore_virtual=True):
+                if next_forward_model_chunk_id == 0:
+                    recv_prev = False
+            if k == (num_microbatches - 1):
+                recv_prev = False
+
+            # Don't send tensor downstream if on last stage.
+            if gpc.is_pipeline_last_stage():
+                output_tensor = None
+
+            with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
+                if not gpc.is_pipeline_first_stage():
+                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                        input_tensor_shapes[next_forward_model_chunk_id])
+            # Send and receive tensors as appropriate (send tensors computed
+            # in this iteration; receive tensors for next iteration).
+            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
+            if k == (num_warmup_microbatches - 1) and not forward_only and \
+                    not all_warmup_microbatches:
+                input_tensor_grad = None
+                recv_next = True
+                if gpc.is_pipeline_last_stage(ignore_virtual=True):
+                    recv_next = False
+                output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
+                input_tensor, output_tensor_grad = \
+                    send_forward_backward_recv_forward_backward(
+                        output_tensor, input_tensor_grad,
+                        input_shape, output_shape,
+                        recv_prev=recv_prev, recv_next=recv_next,
+                        dtype=self.dtype)
+                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+            else:
+                input_tensor = \
+                    send_forward_recv_forward(
+                        output_tensor, input_shape,
+                        recv_prev=recv_prev,
+                        dtype=self.dtype)
+            input_tensors[next_forward_model_chunk_id].append(input_tensor)
+
+        # Run 1F1B in steady state.
+        for k in range(num_microbatches_remaining):
+            # Forward pass.
+            forward_k = k + num_warmup_microbatches
+            output_tensor = forward_step_helper(forward_k)
+
+            # Backward pass.
+            backward_k = k
+            input_tensor_grad = backward_step_helper(backward_k)
+
+            # Send output_tensor and input_tensor_grad, receive input_tensor
+            # and output_tensor_grad.
+
+            # Determine if current stage has anything to send in either direction,
+            # otherwise set tensor to None.
+            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+            gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
+            if gpc.is_pipeline_last_stage():
+                output_tensor = None
+
+            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+            gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
+            if gpc.is_pipeline_first_stage():
+                input_tensor_grad = None
+
+            # Determine if peers are sending, and where in data structure to put
+            # received tensors.
+            recv_prev = True
+            if gpc.is_pipeline_first_stage(ignore_virtual=True):
+                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+                next_forward_model_chunk_id = get_model_chunk_id(
+                    forward_k - (pipeline_parallel_size - 1), forward=True)
+                if next_forward_model_chunk_id == (num_model_chunks - 1):
+                    recv_prev = False
+                next_forward_model_chunk_id += 1
+            else:
+                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+            recv_next = True
+            if gpc.is_pipeline_last_stage(ignore_virtual=True):
+                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+                next_backward_model_chunk_id = get_model_chunk_id(
+                    backward_k - (pipeline_parallel_size - 1), forward=False)
+                if next_backward_model_chunk_id == 0:
+                    recv_next = False
+                next_backward_model_chunk_id -= 1
+            else:
+                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+            # If last iteration, don't receive; we already received one extra
+            # before the start of the for loop.
+            if k == (num_microbatches_remaining - 1):
+                recv_prev = False
+
+            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
+            output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
+            # Communicate tensors.
+            input_tensor, output_tensor_grad = \
+                send_forward_backward_recv_forward_backward(
+                    output_tensor, input_tensor_grad,
+                    input_shape, output_shape,
+                    recv_prev=recv_prev, recv_next=recv_next,
+                    dtype=self.dtype)
+
+            # Put input_tensor and output_tensor_grad in data structures in the
+            # right location.
+            if recv_prev:
+                input_tensors[next_forward_model_chunk_id].append(input_tensor)
+            if recv_next:
+                output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+
+        # Run cooldown backward passes (flush out pipeline).
+        if not forward_only:
+            if all_warmup_microbatches:
+                output_tensor_grads[num_model_chunks - 1].append(
+                    recv_backward(output_tensor_shapes[num_model_chunks - 1]))
+            for k in range(num_microbatches_remaining, num_microbatches):
+                input_tensor_grad = backward_step_helper(k)
+                next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
+                recv_next = True
+                if gpc.is_pipeline_last_stage(ignore_virtual=True):
+                    if next_backward_model_chunk_id == (num_model_chunks - 1):
+                        recv_next = False
+                if k == (num_microbatches - 1):
+                    recv_next = False
+                output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
+                output_tensor_grads[next_backward_model_chunk_id].append(
+                    send_backward_recv_backward(
+                        input_tensor_grad, output_shape,
+                        recv_next=recv_next,
+                        dtype=self.dtype))
+
+        if len(return_tensors) > 0:
+            if return_loss:
+                output, label, loss = tuple(map(list, zip(*return_tensors)))
+                return (torch.cat(output, dim=0),
+                        torch.cat(label, dim=0),
+                        sum(loss))
+            else:
+                return tuple((torch.cat(return_tensors, dim=0), None, None))
+        else:
+            return tuple((None, None, None))
```
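Note: the chunk-id bookkeeping in `forward_backward_step` is easy to sanity-check on paper. A standalone sketch reproducing `get_model_chunk_id` and the warmup count for 4 pipeline stages and 2 model chunks:

```python
pipeline_parallel_size = 4   # physical stages
num_model_chunks = 2         # virtual chunks per stage
num_microbatches = 8 * num_model_chunks  # 16 microbatch slots in total

def get_model_chunk_id(microbatch_id, forward):
    # same formula as the new schedule code
    in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    chunk = in_group // pipeline_parallel_size
    return chunk if forward else num_model_chunks - chunk - 1

# microbatches 0-3 run chunk 0, 4-7 run chunk 1, then the pattern repeats
print([get_model_chunk_id(k, forward=True) for k in range(16)])
# [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]

# warmup count for the first stage (rank 0): earlier stages warm up longer
rank = 0
warmup = (pipeline_parallel_size - rank - 1) * 2 \
         + (num_model_chunks - 1) * pipeline_parallel_size
print(min(warmup, num_microbatches))  # 10
```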
colossalai/initialize.py

```diff
@@ -280,11 +280,21 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")
 
+    # clip grad norm
+    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
+    if clip_grad_norm > 0:
+        if zero_cfg is not None:
+            raise ConfigException(
+                "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
+
     # initialize amp
     amp_mode = None
     if fp16_cfg is not None and fp16_cfg.mode is not None:
         # TODO: pipeline only support NAIVE AMP
         cfg_ = fp16_cfg.copy()
         amp_mode = cfg_.pop('mode')
+        if amp_mode == AMP_TYPE.NAIVE:
+            cfg_['clip_grad'] = clip_grad_norm
         model, optimizer, criterion = convert_to_amp(model=model,
                                                      optimizer=optimizer,
                                                      criterion=criterion,
@@ -357,16 +367,6 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                                               gradient_handlers=gradient_handlers,
                                               lr_scheduler=lr_scheduler)
 
-    # clip grad norm
-    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
-    if clip_grad_norm > 0:
-        if zero_cfg is not None:
-            raise ConfigException(
-                "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
-        elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
-            raise ConfigException(
-                "clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
-
     engine = Engine(model=model,
                     optimizer=optimizer,
```
colossalai/utils/__init__.py

```diff
@@ -3,7 +3,7 @@ from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0,
                      is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
                      is_using_pp, conditional_context, is_model_parallel_parameter,
                      clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
-                     param_is_not_tensor_parallel_duplicate)
+                     param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank)
 from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
@@ -22,5 +22,6 @@ __all__ = ['checkpoint',
            'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient',
-           'DataParallelSampler', 'get_dataloader'
+           'DataParallelSampler', 'get_dataloader',
+           'switch_virtual_pipeline_parallel_rank'
            ]
```
colossalai/utils/common.py

```diff
@@ -249,3 +249,13 @@ def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (
         gpc.get_local_rank(ParallelMode.TENSOR) == 0)
+
+
+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+    prev_rank = gpc.virtual_pipeline_parallel_rank
+    try:
+        gpc.set_virtual_pipeline_parallel_rank(rank)
+        yield
+    finally:
+        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
```
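Note: the new context manager follows the usual save/restore pattern. A hedged usage sketch (assumes an initialized `gpc`):

```python
from colossalai.core import global_context as gpc
from colossalai.utils import switch_virtual_pipeline_parallel_rank

gpc.set_virtual_pipeline_parallel_rank(0)
with switch_virtual_pipeline_parallel_rank(1):
    # communication for the *next* model chunk happens here while the
    # context temporarily reports virtual rank 1
    pass
# on exit the previous virtual rank (0) is restored, even on exception
```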
docs/parallelization.md

````diff
@@ -172,10 +172,10 @@ elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
 2. Make sure your model inherit `colossalai.nn.model.ModelFromConfig` and registered into the
 `MODELS` registry. Define the `self.layers_cfg` attribute.
 Pass in a dict/Config object which specifies the parameters of your model.
-Use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers.
+Use `colossalai.builder.pipeline.build_pipeline_model_from_cfg` to partition the layers.
 
 ```python
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.nn.model import ModelFromConfig
 from colossalai.registry import MODELS
@@ -199,8 +199,11 @@ model_cfg = dict(
    ...
 )
 
-initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
-model = initializer.initialize()
+# from config
+model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1)
+
+# from torch.nn.Sequential
+# model = build_pipeline_model(sequential_model, num_model_chunks)
 
 ```
@@ -214,6 +217,9 @@ engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criteri
 schedule = PipelineSchedule(num_microbatches=4)
 
+# interleaved pipeline
+# schedule = InterleavedPipelineSchedule(num_microbatches=4, num_model_chunks=2)
+
 # execute a training epoch
 data_iter = iter(train_dataloader)
````
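Note: putting the pieces together, one training step with the interleaved schedule might look like the sketch below (hedged illustration; it assumes the engine, schedule, and `data_iter` from the doc snippet above, and that `forward_backward_step` returns `(output, label, loss)` as in the new schedule code).

```python
# one training step with the interleaved 1F1B schedule
engine.train()
engine.zero_grad()
output, label, loss = schedule.forward_backward_step(
    engine, data_iter, forward_only=False, return_loss=True)
engine.step()
```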
tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py

```diff
@@ -6,7 +6,7 @@ from colossalai.logging import get_dist_logger
 import colossalai
 import torch
 import os
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_dataloader, MultiTimer
 from colossalai.nn.loss import CrossEntropyLoss2D
@@ -50,8 +50,7 @@ def test_hybrid_parallel():
     #                         suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
 
     # build vit-t-32
-    initializer = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1)
-    model = initializer.initialize()
+    model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1)
 
     # build dataloaders
     train_dataset = CIFAR10(
@@ -139,4 +138,4 @@ def test_hybrid_parallel():
 
 if __name__ == '__main__':
-    main()
+    test_hybrid_parallel()
```
tests/test_trainer/test_pipeline/test_partition.py

```diff
@@ -5,7 +5,7 @@ import torch
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
-from colossalai.builder.pipeline import PipelineModelInitializer
+from colossalai.builder.pipeline import build_pipeline_model_from_cfg
 from colossalai.core import global_context
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
@@ -28,7 +28,7 @@ def run_partition(rank, world_size):
     logger.info('finished initialization')
 
     # build model
-    model = PipelineModelInitializer(global_context.config.model, 1, verbose=True).initialize()
+    model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
     assert isinstance(model, torch.nn.Module)
     logger.info('model is created')
```
tests/test_trainer/test_pipeline/test_pipeline_schedule.py

```diff
@@ -8,7 +8,7 @@ import torch
 import torch.multiprocessing as mp
 import model
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.communication import p2p as p2p_communication
 from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
 from colossalai.context.parallel_mode import ParallelMode
@@ -39,7 +39,7 @@ def run_schedule(rank, world_size):
            backend='nccl')
 
     # build model
-    model = PipelineModelInitializer(gpc.config.model, 1).initialize()
+    model = build_pipeline_model_from_cfg(gpc.config.model, 1)
     print_rank_0('model is created')
 
     train_dataset = CIFAR10(
```