OpenDAS / ColossalAI / Commits / e532679c

Commit e532679c, authored Jan 10, 2023 by oahzxl
Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk
Parents: c1492e50 7d5640b9

The merge touches 441+ files in total; the 20 changed files listed below, with 1706 additions and 257 deletions (+1706 / -257), are shown in this view.
colossalai/nn/parallel/gemini_parallel.py              +57   -0
colossalai/nn/parallel/utils.py                        +112  -20
colossalai/pipeline/__init__.py                        +1    -1
colossalai/pipeline/layer_spec.py                      +0    -0
colossalai/pipeline/middleware/__init__.py             +3    -0
colossalai/pipeline/middleware/adaptor/__init__.py     +3    -0
colossalai/pipeline/middleware/adaptor/fx.py           +145  -0
colossalai/pipeline/middleware/topo.py                 +206  -0
colossalai/pipeline/pipelinable.py                     +1    -1
colossalai/pipeline/rpc/_pipeline_base.py              +565  -105
colossalai/pipeline/rpc/_pipeline_schedule.py          +4    -6
colossalai/pipeline/rpc/utils.py                       +19   -5
colossalai/pipeline/utils.py                           +5    -2
colossalai/tensor/__init__.py                          +12   -13
colossalai/tensor/colo_parameter.py                    +27   -10
colossalai/tensor/colo_tensor.py                       +38   -25
colossalai/tensor/comm_spec.py                         +175  -10
colossalai/tensor/dist_spec_mgr.py                     +6    -4
colossalai/tensor/param_op_hook.py                     +87   -36
colossalai/tensor/shape_consistency.py                 +240  -19
colossalai/nn/parallel/gemini_parallel.py (new file)
from typing import Optional

import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats

from .data_parallel import ZeroDDP


class GeminiDDP(ZeroDDP):

    def __init__(self,
                 module: torch.nn.Module,
                 device: torch.device,
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 search_range_mb: int = 32,
                 hidden_dim: Optional[int] = None,
                 min_chunk_size_mb: Optional[float] = None,
                 memstats: Optional[MemStats] = None) -> None:
        """
        A torch.nn.Module wrapper using ZeRO-DP and Gemini.
        ZeRO is for parallelism. Gemini is for memory management.
        WARNING: The class will modify the module in place!

        Example:
            model is initialized under the context of ColoInitContext
            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
            >>> logits = model(x)
            >>> loss = criterion(logits, labels)
            >>> model.backward(loss)

        Args:
            module (torch.nn.Module): the model to be wrapped.
            device (torch.device): device to place the model.
            placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
            pin_memory (bool, optional): use pinned memory on the CPU. Defaults to False.
            force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
            search_range_mb (int, optional): chunk size search range in megabytes. Defaults to 32.
            hidden_dim (int, optional): the hidden dimension of the DNN.
                Users can provide this argument to speed up searching.
                If users do not know this argument before training, it is ok. We will use a default value 1024.
            min_chunk_size_mb (float, optional): the minimum chunk size in megabytes.
                If the aggregate size of parameters is still smaller than the minimum chunk size,
                all parameters will be compacted into one small chunk.
            memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
        """
        chunk_manager = init_chunk_manager(model=module,
                                           init_device=device,
                                           hidden_dim=hidden_dim,
                                           search_range_mb=search_range_mb,
                                           min_chunk_size_mb=min_chunk_size_mb)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
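To make the docstring's example concrete, a minimal end-to-end sketch might look like the following. This is an illustration only, not part of the commit: the ColoInitContext import path, the toy model, and the loss are assumptions, and a distributed environment is assumed to be already launched.

# A minimal sketch, assuming colossalai has been launched (e.g. via colossalai.launch)
# and that ColoInitContext is importable from colossalai.utils.model.colo_init_context;
# the toy model below is made up for illustration.
import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=torch.cuda.current_device()):
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))

# wrap the model: ZeRO-DP handles parallelism, Gemini manages chunk placement
model = GeminiDDP(model, torch.cuda.current_device(), placement_policy="auto")
criterion = torch.nn.CrossEntropyLoss()

x = torch.randn(4, 32).cuda()
labels = torch.randint(0, 8, (4,)).cuda()

logits = model(x)
loss = criterion(logits, labels)
model.backward(loss)    # note: model.backward(loss), not loss.backward()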
colossalai/nn/parallel/utils.py
Previous content (removed):

import torch
import torch.distributed as dist

from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device


def get_temp_total_chunk_on_cuda(chunk: Chunk):
    if chunk.is_gathered:
        return chunk.chunk_total
    if chunk.cuda_shard is not None:
        shard_temp = chunk.cuda_shard
    else:
        shard_temp = chunk.cpu_shard.to(get_current_device())

    total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
    gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
    dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)

    return total_temp
Updated content (added):

from collections import OrderedDict
from copy import copy
from typing import Optional, Set

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device


def get_temp_total_chunk_on_cuda(chunk: Chunk):
    if chunk.is_gathered:
        return chunk.cuda_global_chunk
    if chunk.cuda_shard is not None:
        shard_temp = chunk.cuda_shard
    else:
        shard_temp = chunk.cpu_shard.to(get_current_device())

    total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
    gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
    dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)

    return total_temp


def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
    """Get a DFS module list of the given module. Its order is the same as the order of creation of the modules.
    """
    if memo is None:
        memo = set()
    if module not in memo:
        for name, submodule in module._modules.items():
            if submodule is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
                yield m

        memo.add(module)
        yield prefix, module


def _get_shallow_copy_model(model: nn.Module):
    """Get a shallow copy of the given model. Each submodule is different from the original submodule,
    but the new submodule and the old submodule share all attributes.
    """
    old_to_new = dict()
    for name, module in _get_dfs_module_list(model):
        new_module = copy(module)
        new_module._modules = OrderedDict()
        for subname, submodule in module._modules.items():
            if submodule is None:
                continue
            setattr(new_module, subname, old_to_new[submodule])
        old_to_new[module] = new_module
    return old_to_new[model]


def get_static_torch_model(zero_ddp_model,
                           device=torch.device("cpu"),
                           dtype=torch.float32,
                           only_rank_0=True) -> torch.nn.Module:
    """Get a static torch.nn.Module model from the given ZeroDDP module.
    Note that the original ZeroDDP model is not modified, so you can still use it for further training.
    However, you should not train the returned torch model, as that can cause unexpected errors.

    Args:
        zero_ddp_model (ZeroDDP): a zero ddp model
        device (torch.device): the device of the final torch model
        dtype (torch.dtype): the dtype of the final torch model
        only_rank_0 (bool): if True, only rank 0 has the converted torch model

    Returns:
        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
    """
    from colossalai.nn.parallel import ZeroDDP
    assert isinstance(zero_ddp_model, ZeroDDP)

    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
    colo_model = zero_ddp_model.module
    torch_model = _get_shallow_copy_model(colo_model)

    if not only_rank_0 or dist.get_rank() == 0:
        # record the mapping relationship between colo parameters and torch parameters
        colo_to_torch = dict()
        for (name, colo_module), (_, torch_module) in \
                zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
            # clean the parameter list of the new torch module
            torch_module._parameters = OrderedDict()
            for suffix_param_name, param in colo_module.named_parameters(recurse=False):
                # get the full name of the parameter
                full_param_name = name + ('.' if name else '') + suffix_param_name

                if full_param_name not in state_dict:
                    # this means the parameter is shared by multiple modules
                    # we should use colo_to_torch to get the torch parameter created before
                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
                    torch_param = colo_to_torch[param]
                else:
                    # we meet the parameter the first time, just use the state dict to get the data
                    state_param = state_dict[full_param_name]
                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
                    colo_to_torch[param] = torch_param

                setattr(torch_module, suffix_param_name, torch_param)
    dist.barrier()

    return torch_model
colossalai/pipeline/__init__.py
 from .pipelinable import PipelinableContext, PipelinableModel
-from .layer_sepc import LayerSpec
+from .layer_spec import LayerSpec

 __all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
colossalai/pipeline/layer_sepc.py → colossalai/pipeline/layer_spec.py (file moved)
colossalai/pipeline/middleware/__init__.py (new file)

from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal

__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
colossalai/pipeline/middleware/adaptor/__init__.py (new file)

from .fx import get_topology as get_fx_topology

__all__ = ['get_fx_topology']
colossalai/pipeline/middleware/adaptor/fx.py (new file)

from torch.fx.graph_module import GraphModule

from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo

import torch


def partition_name_to_id(partition_name, is_input=False, is_output=False):
    if is_input:
        partition_id = 0
    elif is_output:
        partition_id = 1
    else:
        prefix = 'submod_'
        partition_id = int(partition_name.split(prefix)[-1]) + 2
    return partition_id


# There are two kinds of defs in fx.graph:
# 1. non direct_use & non direct_def, which means the output is used by the next partition through a temporary mid value.
#    e.g. submod1 = call_module(...)
#         temporary_val = submod1[0]
#         submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def, which means the output is used by the next partition directly.
#    e.g. submod1 = call_module(...)
#         submod2 = call_module(submod1, ...)
def find_input_in_partition(node, partitions, input_partitions=None):
    p_input_val = None
    direct_def = not node.name.startswith('getitem')

    # search in input
    if direct_def and input_partitions is not None:
        partition_id = partition_name_to_id('', is_input=True)
        for i, input_node in enumerate(input_partitions):
            if input_node == node:
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
                return p_input_val

    # search submod in mid part
    if direct_def:
        for partition in partitions:
            if partition == node:
                partition_id = partition_name_to_id(partition.name)
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
                return p_input_val

    # search temporary value in graph
    else:
        for partition in partitions:
            for offset, mid_val in enumerate(partition.users):
                if mid_val == node:
                    partition_id = partition_name_to_id(partition.name)
                    p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
                    return p_input_val

    return p_input_val


def find_output_in_partition(node, partitions, output_partitions=None):
    p_output_val = PartitionOutputVal()
    for user in node.users:
        direct_use = not user.name.startswith('getitem')

        # user is mid partition
        for partition in partitions:
            # direct call
            if direct_use:
                if user == partition:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == node:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break
            # getitem call
            else:
                if user in partition.args:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == user:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break

        # user is output
        if output_partitions is not None:
            output_node = output_partitions[0]
            if user.op == output_node.op:
                output_keys = {}
                partition_id = partition_name_to_id('', is_output=True)
                torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
                for i, arg in enumerate(output_keys):
                    if arg == node:
                        p_output_val.add(partition_id=partition_id, offset=i)
                        break
    return p_output_val


def get_topology(gm: GraphModule):
    topo = Topo()
    topo_output_partition = Partition()

    input_partitions = []
    partitions = []
    output_partitions = []

    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)
        else:
            continue

    # set output for input_partition
    topo_input_partition = Partition()
    for partition in input_partitions:
        cur_node = partition
        p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
        topo_input_partition.add_output_val(p_output_val)
    topo.set_partitions(partition_id=0, partition=topo_input_partition)
    topo.set_input_partition_id(partition_id=0)

    for i, partition in enumerate(partitions):
        topo_mid_partition = Partition()
        # set input for submodule
        for arg in partition.args:
            cur_node = arg
            p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
            topo_mid_partition.add_input_val(p_input_val)
        # set output for submodule
        direct_use = True
        for user in partition.users:
            if user.name.startswith('getitem'):
                direct_use = False
                break
        if direct_use:
            cur_node = partition
            p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
            topo_mid_partition.add_output_val(p_output_val)
        else:
            for user in partition.users:
                cur_node = user
                p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
                topo_mid_partition.add_output_val(p_output_val)
        topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)

    # set input for output_partition
    for partition in output_partitions:
        topo_output_partition = Partition()
        torch.fx.graph.map_arg(partition.args[0],
                               lambda n: topo_output_partition.add_input_val(
                                   find_input_in_partition(n, partitions, input_partitions)))
    topo.set_partitions(partition_id=1, partition=topo_output_partition)
    topo.set_output_partition_id(partition_id=1)

    return topo
colossalai/pipeline/middleware/topo.py (new file)

from typing import Dict, List
from dataclasses import dataclass

# This file includes the data structures used by the Pipeline Middleware.


@dataclass
class ValPosition:
    partition_id: int
    offset: int

    def __str__(self) -> str:
        res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class PartitionInputVal(object):

    def __init__(self, partition_id, offset) -> None:
        # every input records which partition_id and which offset it comes from
        val_pos = ValPosition(partition_id, offset)
        self._from_partition_and_offset: ValPosition = val_pos

    def get(self):
        return self._from_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += f'<-({self._from_partition_and_offset})'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class PartitionOutputVal(object):

    def __init__(self) -> None:
        # every output records which partition_id and which offset it goes to
        self._to_partition_and_offset: List[ValPosition] = []

    def add(self, partition_id, offset):
        val_pos = ValPosition(partition_id, offset)
        self._to_partition_and_offset.append(val_pos)

    def get(self):
        return self._to_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += '->('
        for val_pos in self._to_partition_and_offset:
            res += f'{val_pos},'
        res += ')'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class Partition(object):

    def __init__(self) -> None:
        self._input_vals: List[PartitionInputVal] = []
        self._output_vals: List[PartitionOutputVal] = []

    def add_input_val(self, input_val: PartitionInputVal):
        self._input_vals.append(input_val)

    def add_output_val(self, output_val: PartitionOutputVal):
        self._output_vals.append(output_val)

    def get_input_vals(self):
        return self._input_vals

    def get_output_vals(self):
        return self._output_vals

    # get the output offsets sent to dst_partition_id
    def get_output_offsets(self, dst_partition_id):
        res = []
        for offset, output_val in enumerate(self._output_vals):
            outputs = output_val.get()
            for val_pos in outputs:
                if val_pos.partition_id == dst_partition_id:
                    res.append(offset)
        return res

    # get all input src partition_ids
    def get_input_partition_ids(self):
        res = []
        for input_val in self._input_vals:
            val_pos = input_val.get()
            if val_pos.partition_id not in res:
                res.append(val_pos.partition_id)
        return res

    # get all output dst partition_ids
    def get_output_partition_ids(self):
        res = []
        for output_val in self._output_vals:
            outputs = output_val.get()
            for val_pos in outputs:
                if val_pos.partition_id not in res:
                    res.append(val_pos.partition_id)
        return res

    def __str__(self) -> str:
        res = ''
        res += f'  input:\n'
        res += f'    length: {len(self._input_vals)}\n'
        for i, input_val in enumerate(self._input_vals):
            res += f'    offset={i}: {input_val}\n'
        res += f'  output:\n'
        res += f'    length: {len(self._output_vals)}\n'
        for i, output_val in enumerate(self._output_vals):
            res += f'    offset={i}: {output_val}\n'
        return res

    def __repr__(self) -> str:
        return self.__str__()


# This class is a middleware between the partition splitter
# and the Pipeline Scheduler. It records the graph info about
# partition input/output and provides it to the scheduler.
# There are three kinds of partitions in the Pipeline Middleware design,
# which together represent the whole process of a model execution: input-fwd-output.
# 1. input_partition: records the input of a model.
# 2. mid_partition: records the split forward executions of a model.
# 3. output_partition: records the output of a model.
# Attributes:
#   _partitions: includes all partitions
#   _input_partition_id: the key that represents the input_partition
#   _output_partition_id: the key that represents the output_partition
class Topo(object):

    def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
        self._partitions: Dict[int, Partition] = {}
        self._input_partition_id = input_partition_id
        self._output_partition_id = output_partition_id

    def set_input_partition_id(self, partition_id: int):
        self._input_partition_id = partition_id

    def set_output_partition_id(self, partition_id: int):
        self._output_partition_id = partition_id

    def get_input_partition_id(self):
        return self._input_partition_id

    def get_output_partition_id(self):
        return self._output_partition_id

    def set_partitions(self, partition_id: int, partition: Partition):
        self._partitions[partition_id] = partition

    def get_mid_partitions(self):
        res = {}    # {partition_id: Partition}
        for partition_id, partition in self._partitions.items():
            if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
                continue
            res[partition_id] = partition
        return res

    def get_mid_partition_ids(self):
        return list(self.get_mid_partitions().keys())

    def get_input_partition(self):
        if self._input_partition_id is not None:
            return self._partitions[self._input_partition_id]
        return None

    def get_output_partition(self):
        if self._output_partition_id is not None:
            return self._partitions[self._output_partition_id]
        return None

    def get_partition_by_id(self, partition_id):
        return self._partitions[partition_id]

    def __str__(self) -> str:
        res = ''
        if len(self._partitions) == 0:
            return 'Empty Topo Graph.'

        input_part = self.get_input_partition()
        if input_part is not None:
            res += '{\n'
            res += f'InputPartition:\n  partition_id={self._input_partition_id}\n{input_part}'
            res += '}\n'

        mid_parts = self.get_mid_partitions()
        for i, (partition_id, part) in enumerate(mid_parts.items()):
            res += '{\n'
            res += f'SubPartition_{i}:\n  partition_id={partition_id}\n{part}'
            res += '}\n'

        output_part = self.get_output_partition()
        if output_part is not None:
            res += '{\n'
            res += f'OutputPartition:\n  partition_id={self._output_partition_id}\n{output_part}'
            res += '}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()
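To illustrate how these data structures fit together, here is a small hand-built topology that follows the id convention used by the fx adaptor above (0 = input partition, 1 = output partition, 2 and up = submodules). The two-stage layout itself is made up for illustration; every method call comes from the classes defined in this file.

from colossalai.pipeline.middleware import Topo, Partition, PartitionInputVal, PartitionOutputVal

# input partition (id 0): its single value is consumed by partition 2 at offset 0
input_part = Partition()
out_val = PartitionOutputVal()
out_val.add(partition_id=2, offset=0)
input_part.add_output_val(out_val)

# first submodule (id 2): reads from the input partition, writes to partition 3
mid_part_a = Partition()
mid_part_a.add_input_val(PartitionInputVal(partition_id=0, offset=0))
out_val = PartitionOutputVal()
out_val.add(partition_id=3, offset=0)
mid_part_a.add_output_val(out_val)

# second submodule (id 3): reads from partition 2, writes to the output partition
mid_part_b = Partition()
mid_part_b.add_input_val(PartitionInputVal(partition_id=2, offset=0))
out_val = PartitionOutputVal()
out_val.add(partition_id=1, offset=0)
mid_part_b.add_output_val(out_val)

# output partition (id 1): collects the value produced by partition 3
output_part = Partition()
output_part.add_input_val(PartitionInputVal(partition_id=3, offset=0))

topo = Topo()
topo.set_partitions(partition_id=0, partition=input_part)
topo.set_input_partition_id(partition_id=0)
topo.set_partitions(partition_id=2, partition=mid_part_a)
topo.set_partitions(partition_id=3, partition=mid_part_b)
topo.set_partitions(partition_id=1, partition=output_part)
topo.set_output_partition_id(partition_id=1)

print(topo.get_mid_partition_ids())    # [2, 3]
print(topo)                            # pretty-printed partitions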
colossalai/pipeline/pipelinable.py

@@ -9,7 +9,7 @@ from colossalai.nn.layer.utils import CheckpointModule
 from colossalai.tensor import ColoParameter
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from .layer_sepc import LayerSpec
+from .layer_spec import LayerSpec


 class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
colossalai/pipeline/rpc/_pipeline_base.py
View file @
e532679c
...
...
@@ -8,18 +8,28 @@ from typing import Any, Callable, Dict, List, Tuple
import
torch
import
torch.distributed.rpc
as
rpc
from
colossalai.pipeline.pipeline_process_group
import
ppg
from
colossalai.pipeline.rpc.utils
import
(
get_batch_lengths
,
get_real_args_kwargs
,
pytree_filter
,
pytree_map
,
split_batch
,
tensor_shape_list
,
type_detail
)
from
torch
import
autograd
,
nn
,
optim
from
torch._C._distributed_rpc
import
PyRRef
from
torch.futures
import
Future
from
colossalai.pipeline.middleware
import
Partition
,
PartitionInputVal
,
PartitionOutputVal
,
Topo
from
colossalai.pipeline.pipeline_process_group
import
ppg
from
colossalai.pipeline.rpc.utils
import
(
get_batch_lengths
,
pyobj_map
,
pytree_filter
,
pytree_map
,
split_batch
,
tensor_shape_list
,
type_detail
,
)
class
Phase
(
Enum
):
FORWARD
=
0
BACKWARD
=
1
UPDATE
=
2
INPUT
=
3
class
UniqueKey
:
...
...
@@ -134,6 +144,7 @@ class WorkerBase(ABC):
self
.
partition_args
=
partition_args
self
.
criterion
=
criterion
self
.
metric
=
metric
self
.
reset
=
False
# context to maintain loop
self
.
_initialize_context_container
()
...
...
@@ -164,6 +175,7 @@ class WorkerBase(ABC):
self
.
work_list_condition_lock
=
threading
.
Condition
(
threading
.
Lock
())
self
.
output_list_condition_lock
=
threading
.
Condition
(
threading
.
Lock
())
self
.
label_lock
=
threading
.
Condition
(
threading
.
Lock
())
self
.
reset_condition
=
threading
.
Condition
(
threading
.
Lock
())
def
_initialize_partition
(
self
):
partition_fn
=
self
.
partition_fn
...
...
@@ -173,6 +185,41 @@ class WorkerBase(ABC):
self
.
module_partition
:
nn
.
Module
=
partition_fn
(
*
partition_args
).
to
(
device
)
self
.
partition_condition_lock
.
notify_all
()
def
_get_output_all
(
self
,
key
:
UniqueKey
,
ref_use
=
False
,
rank
=
None
):
with
self
.
output_list_condition_lock
:
self
.
output_list_condition_lock
.
wait_for
(
lambda
:
key
in
self
.
output_list
)
output_work_item
=
self
.
output_list
[
key
]
output
=
output_work_item
.
output
if
not
ref_use
and
output_work_item
.
phase
!=
Phase
.
INPUT
:
self
.
output_list
.
pop
(
key
)
if
not
ref_use
and
output_work_item
.
phase
!=
Phase
.
INPUT
:
output_work_item
.
refcount
+=
1
refcount
=
output_work_item
.
refcount
# lifecycle management for DAG scheduler
if
output_work_item
.
phase
==
Phase
.
FORWARD
:
lifecycle
=
len
(
self
.
get_consumer_stage_ids
())
if
self
.
is_model_output
():
# an extra reference for scheduler collecting results
lifecycle
+=
1
elif
output_work_item
.
phase
==
Phase
.
BACKWARD
:
lifecycle
=
len
(
self
.
get_producer_stage_ids
())
if
self
.
is_model_input
()
and
self
.
_is_last_step
(
output_work_item
):
# an extra reference for ensure_backward
lifecycle
+=
1
else
:
lifecycle
=
0
refcount
=
0
with
self
.
output_list_condition_lock
:
if
refcount
<
lifecycle
:
self
.
output_list
[
key
]
=
output_work_item
self
.
output_list_condition_lock
.
notify_all
()
if
isinstance
(
output
,
Future
):
output
=
output
.
wait
()
return
output
def
sync_global_worker_rrefs
(
self
,
pp_rank_to_worker_rref
:
Dict
[
int
,
PyRRef
])
->
None
:
assert
self
.
pp_rank_to_worker_rref
is
None
,
f
"in rank
{
self
.
pp_rank
}
, worker has sync global workers rrefs"
assert
pp_rank_to_worker_rref
is
not
None
,
"stage_to_workers must be a dict instead of None"
...
...
@@ -182,23 +229,21 @@ class WorkerBase(ABC):
# construction of partition is executed after the registion of pp_rank_to_worker_rref
self
.
_initialize_partition
()
def
get_output_by_key
(
self
,
key
:
UniqueKey
)
->
Any
:
with
self
.
output_list_condition_lock
:
self
.
output_list_condition_lock
.
wait_for
(
lambda
:
key
in
self
.
output_list
)
output_work_item
=
self
.
output_list
[
key
]
output
=
output_work_item
.
output
if
isinstance
(
output
,
Future
):
output
=
output
.
wait
()
output_work_item
.
refcount
+=
1
# all consumers have been satisfied, the work_item can be released
with
self
.
output_list_condition_lock
:
if
output_work_item
.
refcount
>=
len
(
self
.
consumer_stage_ids
):
self
.
output_list
.
pop
(
key
)
# res_use works for lifecycle counter,
# if ref_use is True, lifecycle won't add.
# offset supports get partial output to reduce comm costs.
def
get_output_by_key
(
self
,
key
:
UniqueKey
,
ref_use
=
False
,
rank
=
None
,
offsets
=
None
)
->
Any
:
output
=
self
.
_get_output_all
(
key
,
ref_use
,
rank
)
if
offsets
is
None
:
# get all for non iterable output
return
output
else
:
# get part for iterable output
output
=
[
output
[
i
]
for
i
in
offsets
]
return
output
def
get_numels
(
self
)
->
int
:
numel
=
sum
(
param
.
numel
()
for
param
in
self
.
module_partition
.
parameters
())
return
numel
def
get_parameters
(
self
)
->
List
[
torch
.
Tensor
]:
return
[
p
for
p
in
self
.
module_partition
.
parameters
()]
...
...
@@ -215,8 +260,10 @@ class WorkerBase(ABC):
self
.
partition_condition_lock
.
wait_for
(
lambda
:
hasattr
(
self
,
'module_partition'
))
return
self
.
module_partition
.
state_dict
()
def
_make_args_kwargs
(
self
,
microbatch
):
def
_make_args_kwargs
(
self
,
microbatch
,
merge
=
False
):
if
isinstance
(
microbatch
,
dict
):
if
merge
:
return
list
(
microbatch
.
values
()),
{}
return
[],
microbatch
elif
isinstance
(
microbatch
,
torch
.
Tensor
):
return
[
microbatch
],
{}
...
...
@@ -228,24 +275,58 @@ class WorkerBase(ABC):
kwargs
.
update
(
arg
)
else
:
args
.
append
(
arg
)
if
merge
:
arg_lst
=
args
for
arg
in
kwargs
.
values
():
arg_lst
.
append
(
arg
)
return
arg_lst
,
{}
return
args
,
kwargs
else
:
raise
TypeError
(
f
"Input batch can be only dict, list, tuple or tensor, but receive
{
type
(
microbatch
)
}
"
)
# just for first pp_rank
def
set_input
(
self
,
microbatch_id
:
int
,
microbatch
:
Tuple
[
Any
],
forward_only
:
bool
):
assert
self
.
consumer_stage_ids
is
not
None
key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
output
=
self
.
_get_future_by_device
()
# make args and kwargs
args
,
kwargs
=
self
.
_make_args_kwargs
(
microbatch
)
if
not
self
.
use_middleware
():
# make args and kwargs
args
,
kwargs
=
self
.
_make_args_kwargs
(
microbatch
)
work_item
=
WorkItem
(
self
.
pp_rank
,
Phase
.
FORWARD
,
args
,
kwargs
,
output
,
microbatch_id
,
None
,
self
.
num_microbatches
,
forward_only
)
with
self
.
work_list_condition_lock
:
self
.
work_list
[
key
]
=
work_item
self
.
work_list_condition_lock
.
notify_all
()
work_item
=
WorkItem
(
self
.
pp_rank
,
Phase
.
FORWARD
,
args
,
kwargs
,
output
,
microbatch_id
,
None
,
self
.
num_microbatches
,
forward_only
)
with
self
.
work_list_condition_lock
:
self
.
work_list
[
key
]
=
work_item
self
.
work_list_condition_lock
.
notify_all
()
else
:
# make args and kwargs
arg_lst
,
_
=
self
.
_make_args_kwargs
(
microbatch
,
merge
=
True
)
# first stage assign correct input into other stages
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
self
.
pp_rank
,
topo
)
input_partition
=
topo
.
get_input_partition
()
self_input_offsets
=
input_partition
.
get_output_offsets
(
self_partition_id
)
recv_input_key
=
UniqueKey
(
microbatch_id
,
Phase
.
INPUT
)
# set input for self rank
self_arg_lst
=
[]
for
off
in
self_input_offsets
:
self_arg_lst
.
append
(
arg_lst
[
off
])
work_item
=
WorkItem
(
self
.
pp_rank
,
Phase
.
FORWARD
,
self_arg_lst
,
{},
output
,
microbatch_id
,
None
,
self
.
num_microbatches
,
forward_only
)
with
self
.
work_list_condition_lock
:
self
.
work_list
[
key
]
=
work_item
self
.
work_list_condition_lock
.
notify_all
()
# put input tensor which other nodes need into output_list as Phase.INPUT
work_item_remote
=
WorkItem
(
self
.
pp_rank
,
Phase
.
INPUT
,
[],
{},
arg_lst
,
microbatch_id
,
None
,
self
.
num_microbatches
,
forward_only
)
with
self
.
output_list_condition_lock
:
self
.
output_list
[
recv_input_key
]
=
work_item_remote
self
.
output_list_condition_lock
.
notify_all
()
# just for last pp_rank
def
set_labels
(
self
,
microbatch_id
:
int
,
microlabels
:
Any
):
...
...
@@ -268,62 +349,159 @@ class WorkerBase(ABC):
self
.
work_list
[
key
]
=
work_item
self
.
work_list_condition_lock
.
notify_all
()
def
subscribe_producer
(
self
,
microbatch_id
:
int
,
forward_only
:
bool
):
def
_
subscribe_producer
(
self
,
microbatch_id
:
int
,
forward_only
:
bool
):
"""
You should call this function asynchronously
"""
assert
self
.
producer_stage_ids
is
not
None
producer_num
=
len
(
self
.
producer_stage_ids
)
assert
producer_num
>
0
,
"only stage that has producers can subscribe producers"
stage_id
=
self
.
pp_rank
subscribe_forward_futures
:
List
[
Future
]
=
[
None
]
*
producer_num
output
=
self
.
_get_future_by_device
()
if
not
self
.
use_middleware
():
producer_num
=
len
(
self
.
producer_stage_ids
)
subscribe_forward_futures
:
List
[
Future
]
=
[
None
]
*
producer_num
for
i
in
range
(
producer_num
):
producer_stage_id
=
self
.
producer_stage_ids
[
i
]
producer_output_key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
producer_worker_rref
=
self
.
pp_rank_to_worker_rref
[
producer_stage_id
]
subscribe_forward_futures
[
i
]
=
producer_worker_rref
.
rpc_async
().
get_output_by_key
(
producer_output_key
)
else
:
producer_stage_ids
=
self
.
get_producer_stage_ids
()
producer_num
=
len
(
producer_stage_ids
)
if
self
.
need_model_input
():
producer_num
+=
1
# for input partition
subscribe_forward_futures
:
List
[
Future
]
=
[
None
]
*
producer_num
# TODO(jiangziyue) get single value instead of the whole output
if
self
.
need_model_input
():
producer_stage_id
=
0
producer_output_key
=
UniqueKey
(
microbatch_id
,
Phase
.
INPUT
)
producer_worker_rref
=
self
.
pp_rank_to_worker_rref
[
producer_stage_id
]
offsets
=
self
.
_get_input_offsets_by_index
(
target_index
=
0
)
subscribe_forward_futures
[
0
]
=
producer_worker_rref
.
rpc_async
().
get_output_by_key
(
producer_output_key
,
rank
=
self
.
pp_rank
,
offsets
=
offsets
)
for
i
in
range
(
0
,
producer_num
-
1
):
producer_stage_id
=
producer_stage_ids
[
i
]
producer_output_key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
producer_worker_rref
=
self
.
pp_rank_to_worker_rref
[
producer_stage_id
]
target_index
=
i
+
1
offsets
=
self
.
_get_input_offsets_by_index
(
target_index
=
target_index
)
if
offsets
is
not
None
and
len
(
offsets
)
==
0
:
# no need to do rpc
subscribe_forward_futures
[
target_index
]
=
[]
else
:
subscribe_forward_futures
[
target_index
]
=
producer_worker_rref
.
rpc_async
().
get_output_by_key
(
producer_output_key
,
rank
=
self
.
pp_rank
)
for
i
in
range
(
producer_num
):
producer_stage_id
=
self
.
producer_stage_ids
[
i
]
producer_output_key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
producer_worker_rref
=
self
.
pp_rank_to_worker_rref
[
producer_stage_id
]
subscribe_forward_futures
[
i
]
=
producer_worker_rref
.
rpc_async
().
get_output_by_key
(
producer_output_key
)
else
:
for
i
in
range
(
producer_num
):
producer_stage_id
=
producer_stage_ids
[
i
]
producer_output_key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
producer_worker_rref
=
self
.
pp_rank_to_worker_rref
[
producer_stage_id
]
target_index
=
i
offsets
=
self
.
_get_input_offsets_by_index
(
target_index
=
target_index
)
if
offsets
is
not
None
and
len
(
offsets
)
==
0
:
# no need to do rpc
subscribe_forward_futures
[
target_index
]
=
[]
else
:
subscribe_forward_futures
[
target_index
]
=
producer_worker_rref
.
rpc_async
().
get_output_by_key
(
producer_output_key
,
rank
=
self
.
pp_rank
,
offsets
=
offsets
)
work_item_from_producer
=
WorkItem
(
stage_id
,
Phase
.
FORWARD
,
subscribe_forward_futures
,
{},
output
,
microbatch_id
,
None
,
self
.
num_microbatches
,
forward_only
)
# add work_item to work_list
with
self
.
work_list_condition_lock
:
key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
assert
key
not
in
self
.
work_list
self
.
work_list
[
key
]
=
work_item_from_producer
self
.
work_list_condition_lock
.
notify_all
()
return
work_item_from_producer
def
subscribe_consumer
(
self
,
microbatch_id
:
int
):
# TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
def
subscribe_producer
(
self
,
microbatch_id
:
int
,
forward_only
:
bool
):
key
=
UniqueKey
(
microbatch_id
,
Phase
.
FORWARD
)
with
self
.
work_list_condition_lock
:
if
key
not
in
self
.
work_list
:
# On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
work_item_from_producer
=
self
.
_subscribe_producer
(
microbatch_id
,
forward_only
)
self
.
work_list
[
key
]
=
work_item_from_producer
self
.
work_list_condition_lock
.
notify_all
()
def
_subscribe_consumer
(
self
,
microbatch_id
:
int
):
"""
You should call this function asynchronously
"""
assert
self
.
producer_stage_ids
is
not
None
consumer_num
=
len
(
self
.
consumer_stage_ids
)
assert
consumer_num
>
0
,
"only stage that has consumers can subscribe comsumers"
stage_id
=
self
.
pp_rank
subscribe_backward_futures
:
List
[
Future
]
=
[
None
]
*
consumer_num
output
=
self
.
_get_future_by_device
()
if
not
self
.
use_middleware
():
consumer_stage_ids
=
self
.
consumer_stage_ids
else
:
consumer_stage_ids
=
self
.
get_consumer_stage_ids
()
consumer_num
=
len
(
consumer_stage_ids
)
subscribe_backward_futures
:
List
[
Future
]
=
[
None
]
*
consumer_num
for
i
in
range
(
consumer_num
):
consumer_stage_id
=
self
.
consumer_stage_ids
[
i
]
consumer_stage_id
=
consumer_stage_ids
[
i
]
consumer_output_key
=
UniqueKey
(
microbatch_id
,
Phase
.
BACKWARD
)
consumer_worker_rref
=
self
.
pp_rank_to_worker_rref
[
consumer_stage_id
]
subscribe_backward_futures
[
i
]
=
consumer_worker_rref
.
rpc_async
().
get_output_by_key
(
consumer_output_key
)
target_index
=
i
offsets
=
self
.
_get_output_offsets_by_index
(
target_index
=
target_index
)
if
offsets
is
not
None
and
len
(
offsets
)
==
0
:
# no need to do rpc
subscribe_backward_futures
[
target_index
]
=
[]
else
:
subscribe_backward_futures
[
target_index
]
=
consumer_worker_rref
.
rpc_async
().
get_output_by_key
(
consumer_output_key
,
rank
=
self
.
pp_rank
,
offsets
=
offsets
)
# flatten args
work_item_from_consumer
=
WorkItem
(
stage_id
,
Phase
.
BACKWARD
,
subscribe_backward_futures
,
{},
output
,
microbatch_id
,
None
,
self
.
num_microbatches
,
False
)
# add work_item to work_list
return
work_item_from_consumer
def
subscribe_consumer
(
self
,
microbatch_id
:
int
):
key
=
UniqueKey
(
microbatch_id
,
Phase
.
BACKWARD
)
with
self
.
work_list_condition_lock
:
key
=
UniqueKey
(
microbatch_id
,
Phase
.
BACKWARD
)
assert
key
not
in
self
.
work_list
self
.
work_list
[
key
]
=
work_item_from_consumer
self
.
work_list_condition_lock
.
notify_all
()
if
key
not
in
self
.
work_list
:
# On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
work_item_from_consumer
=
self
.
_subscribe_consumer
(
microbatch_id
)
self
.
work_list
[
key
]
=
work_item_from_consumer
self
.
work_list_condition_lock
.
notify_all
()
def
get_producer_stage_ids
(
self
):
producer_stage_ids
=
[]
rank
=
self
.
pp_rank
if
not
self
.
use_middleware
():
prev_rank
=
rank
-
1
if
prev_rank
>=
0
:
producer_stage_ids
.
append
(
prev_rank
)
else
:
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
rank
,
topo
)
self_partition
:
Partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
input_partition_ids
=
self_partition
.
get_input_partition_ids
()
model_input_partition_id
=
topo
.
get_input_partition_id
()
for
partition_id
in
input_partition_ids
:
# ignore input partition in current implementation.
# it will be specially tackled.
if
partition_id
!=
model_input_partition_id
:
producer_stage_ids
.
append
(
self
.
partition_id_to_pp_rank
(
partition_id
,
topo
))
return
producer_stage_ids
def
get_consumer_stage_ids
(
self
):
consumer_stage_ids
=
[]
rank
=
self
.
pp_rank
if
not
self
.
use_middleware
():
next_rank
=
rank
+
1
if
next_rank
<=
self
.
actual_stage_num
-
1
:
consumer_stage_ids
.
append
(
next_rank
)
else
:
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
rank
,
topo
)
self_partition
:
Partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
output_partition_ids
=
self_partition
.
get_output_partition_ids
()
model_output_partition_id
=
topo
.
get_output_partition_id
()
for
partition_id
in
output_partition_ids
:
if
model_output_partition_id
!=
partition_id
:
consumer_stage_ids
.
append
(
self
.
partition_id_to_pp_rank
(
partition_id
,
topo
))
return
consumer_stage_ids
def
_get_producer_consumer
(
self
)
->
None
:
rank
=
self
.
pp_rank
...
...
@@ -331,16 +509,212 @@ class WorkerBase(ABC):
assert
self
.
consumer_stage_ids
is
None
,
f
"all the consumers of rank
{
rank
}
has been subscribed"
# should be aranged in order, the order of the input of current forward
self
.
producer_stage_ids
=
[]
self
.
consumer_stage_ids
=
[]
self
.
producer_stage_ids
=
self
.
get_producer_stage_ids
()
self
.
consumer_stage_ids
=
self
.
get_consumer_stage_ids
()
def
pp_rank_to_partition_id
(
self
,
pp_rank
:
int
,
topo
:
Topo
):
partition_ids
=
topo
.
get_mid_partition_ids
()
return
partition_ids
[
pp_rank
]
# Just for demo
prev_rank
=
rank
-
1
next_rank
=
rank
+
1
if
prev_rank
>=
0
:
self
.
producer_stage_ids
.
append
(
prev_rank
)
if
next_rank
<=
self
.
actual_stage_num
-
1
:
self
.
consumer_stage_ids
.
append
(
next_rank
)
def
partition_id_to_pp_rank
(
self
,
partition_id
:
int
,
topo
:
Topo
):
partition_ids
=
topo
.
get_mid_partition_ids
()
for
i
,
id
in
enumerate
(
partition_ids
):
if
id
==
partition_id
:
return
i
def
get_topo
(
self
):
with
self
.
partition_condition_lock
:
self
.
partition_condition_lock
.
wait_for
(
lambda
:
hasattr
(
self
,
'module_partition'
))
if
hasattr
(
self
.
module_partition
,
'_topo'
):
return
self
.
module_partition
.
_topo
else
:
return
None
def
use_middleware
(
self
):
topo
=
self
.
get_topo
()
return
topo
is
not
None
def
_get_input_offsets_by_index
(
self
,
target_index
):
res
=
[]
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
self
.
pp_rank
,
topo
)
self_partition
:
Partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
model_input_partition_id
=
topo
.
get_input_partition_id
()
input_vals
=
self_partition
.
get_input_vals
()
producer_stage_ids
=
self
.
get_producer_stage_ids
()
if
self
.
need_model_input
():
# 0 for data from input batch
# >= 1 for data from prev stages
base
=
1
else
:
# data from prev stages
base
=
0
for
val
in
input_vals
:
val_pos
=
val
.
get
()
src_partition_id
=
val_pos
.
partition_id
src_offset
=
val_pos
.
offset
src_index
=
base
src_partition
=
topo
.
get_partition_by_id
(
src_partition_id
)
output_len
=
len
(
src_partition
.
get_output_vals
())
# data from not-input partition
if
src_partition_id
!=
model_input_partition_id
:
src_stage_id
=
self
.
partition_id_to_pp_rank
(
src_partition_id
,
topo
)
src_index
=
base
for
i
,
stage_id
in
enumerate
(
producer_stage_ids
):
if
stage_id
==
src_stage_id
:
src_index
+=
i
break
else
:
# data from input partition
src_index
=
0
# when output_len = 1, not iterable
if
target_index
==
src_index
:
if
output_len
==
1
:
res
=
None
# offset = None to get all outputs
return
res
else
:
res
.
append
(
src_offset
)
return
res
def
_get_output_offsets_by_index
(
self
,
target_index
):
res
=
[]
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
self
.
pp_rank
,
topo
)
self_partition
:
Partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
output_vals
=
self_partition
.
get_output_vals
()
consumer_stage_ids
=
self
.
get_consumer_stage_ids
()
for
val_list
in
output_vals
:
# An output may be passed to many down stages.
target
=
None
for
val_pos
in
val_list
.
get
():
dst_partition_id
=
val_pos
.
partition_id
dst_offset
=
val_pos
.
offset
dst_partition
=
topo
.
get_partition_by_id
(
dst_partition_id
)
input_len
=
len
(
dst_partition
.
get_input_vals
())
dst_stage_id
=
self
.
partition_id_to_pp_rank
(
dst_partition_id
,
topo
)
for
i
,
stage_id
in
enumerate
(
consumer_stage_ids
):
if
stage_id
==
dst_stage_id
:
dst_index
=
i
break
if
target_index
==
dst_index
:
if
input_len
==
1
:
res
=
None
# offset = None to get all outputs
return
res
else
:
res
.
append
(
dst_offset
)
return
res
# TODO(jiangziyue) get single value instead of the whole output
def
_get_real_args_kwargs_fwd
(
self
,
args_or_kwargs
):
if
not
self
.
use_middleware
():
args_or_kwargs
=
pytree_map
(
args_or_kwargs
,
fn
=
lambda
x
:
x
.
wait
(),
process_types
=
Future
)
if
args_or_kwargs
is
not
None
:
if
isinstance
(
args_or_kwargs
,
dict
):
pass
else
:
flatten_args
=
[]
pytree_map
(
args_or_kwargs
,
fn
=
lambda
x
:
flatten_args
.
append
(
x
),
map_all
=
True
)
args_or_kwargs
=
flatten_args
else
:
args_or_kwargs
=
pytree_map
(
args_or_kwargs
,
fn
=
lambda
x
:
x
.
wait
(),
process_types
=
Future
)
if
args_or_kwargs
is
not
None
:
if
isinstance
(
args_or_kwargs
,
dict
):
pass
else
:
flatten_args
=
[]
if
self
.
is_first_stage
():
pytree_map
(
args_or_kwargs
,
fn
=
lambda
x
:
flatten_args
.
append
(
x
),
map_all
=
True
)
else
:
# get by offset
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
self
.
pp_rank
,
topo
)
self_partition
:
Partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
model_input_partition_id
=
topo
.
get_input_partition_id
()
input_vals
=
self_partition
.
get_input_vals
()
producer_stage_ids
=
self
.
get_producer_stage_ids
()
if
self
.
need_model_input
():
# 0 for data from input batch
# >= 1 for data from prev stages
base
=
1
else
:
# data from prev stages
base
=
0
for
val
in
input_vals
:
val_pos
=
val
.
get
()
src_partition_id
=
val_pos
.
partition_id
src_offset
=
val_pos
.
offset
src_index
=
base
src_partition
=
topo
.
get_partition_by_id
(
src_partition_id
)
output_len
=
len
(
src_partition
.
get_output_vals
())
# data from not-input partition
if
src_partition_id
!=
model_input_partition_id
:
src_stage_id
=
self
.
partition_id_to_pp_rank
(
src_partition_id
,
topo
)
src_index
=
base
for
i
,
stage_id
in
enumerate
(
producer_stage_ids
):
if
stage_id
==
src_stage_id
:
src_index
+=
i
break
else
:
# data from input partition
src_index
=
0
# when output_len = 1, not iterable
if
output_len
==
1
:
target
=
args_or_kwargs
[
src_index
]
else
:
offsets
=
self
.
_get_input_offsets_by_index
(
src_index
)
real_offset
=
offsets
.
index
(
src_offset
)
target
=
args_or_kwargs
[
src_index
][
real_offset
]
flatten_args
.
append
(
target
)
args_or_kwargs
=
flatten_args
return
args_or_kwargs
# TODO(jiangziyue) get single value instead of the whole output
def
_get_real_args_kwargs_bwd
(
self
,
args_or_kwargs
):
if
not
self
.
use_middleware
():
args_or_kwargs
=
pytree_map
(
args_or_kwargs
,
fn
=
lambda
x
:
x
.
wait
(),
process_types
=
Future
)
if
args_or_kwargs
is
not
None
:
if
isinstance
(
args_or_kwargs
,
dict
):
pass
else
:
flatten_args
=
[]
pytree_map
(
args_or_kwargs
,
fn
=
lambda
x
:
flatten_args
.
append
(
x
),
map_all
=
True
)
args_or_kwargs
=
flatten_args
else
:
for
i
,
arg
in
enumerate
(
args_or_kwargs
):
args_or_kwargs
[
i
]
=
arg
.
wait
()
if
args_or_kwargs
is
not
None
:
# get by offset
flatten_args
=
[]
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
self
.
pp_rank
,
topo
)
self_partition
:
Partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
output_vals
=
self_partition
.
get_output_vals
()
consumer_stage_ids
=
self
.
get_consumer_stage_ids
()
for
val_list
in
output_vals
:
# An output may be passed to many down stages.
target
=
None
for
val_pos
in
val_list
.
get
():
dst_partition_id
=
val_pos
.
partition_id
dst_offset
=
val_pos
.
offset
dst_partition
=
topo
.
get_partition_by_id
(
dst_partition_id
)
input_len
=
len
(
dst_partition
.
get_input_vals
())
dst_stage_id
=
self
.
partition_id_to_pp_rank
(
dst_partition_id
,
topo
)
for
i
,
stage_id
in
enumerate
(
consumer_stage_ids
):
if
stage_id
==
dst_stage_id
:
dst_index
=
i
break
if
input_len
==
1
:
part_grad
=
args_or_kwargs
[
dst_index
]
else
:
offsets
=
self
.
_get_output_offsets_by_index
(
dst_index
)
real_offsets
=
offsets
.
index
(
dst_offset
)
part_grad
=
args_or_kwargs
[
dst_index
][
real_offsets
]
if
target
is
None
:
target
=
part_grad
elif
part_grad
is
not
None
:
target
+=
part_grad
else
:
continue
flatten_args
.
append
(
target
)
args_or_kwargs
=
flatten_args
return
args_or_kwargs
@
abstractmethod
def
_get_work_item_key
(
self
)
->
UniqueKey
:
...
...
@@ -354,6 +728,23 @@ class WorkerBase(ABC):
def
is_last_stage
(
self
):
return
self
.
pp_rank
==
self
.
actual_stage_num
-
1
def
need_model_input
(
self
):
need_input
=
False
topo
:
Topo
=
self
.
get_topo
()
self_partition_id
=
self
.
pp_rank_to_partition_id
(
self
.
pp_rank
,
topo
)
self_partition
=
topo
.
get_partition_by_id
(
self_partition_id
)
partition_inputs
=
self_partition
.
get_input_partition_ids
()
model_input_partition_id
=
topo
.
get_input_partition_id
()
if
model_input_partition_id
in
partition_inputs
:
need_input
=
True
return
not
self
.
is_first_stage
()
and
need_input
def
is_model_output
(
self
):
return
self
.
is_last_stage
()
def
is_model_input
(
self
):
return
self
.
is_first_stage
()
def
_default_data_process_func
(
self
,
args_kwargs
):
if
self
.
is_first_stage
():
args
=
args_kwargs
[
0
]
...
...
@@ -390,11 +781,16 @@ class WorkerBase(ABC):
# parse and integrate args and kwargs
if
is_first_stage
:
args
=
get_real_args_kwargs
(
args
)
kwargs
=
get_real_args_kwargs
(
kwargs
)
args
=
self
.
_
get_real_args_kwargs
_fwd
(
args
)
kwargs
=
self
.
_
get_real_args_kwargs
_fwd
(
kwargs
)
args_kwargs
=
(
args
,
kwargs
)
else
:
args_kwargs
=
get_real_args_kwargs
(
args
)
args_kwargs
=
self
.
_get_real_args_kwargs_fwd
(
args
)
args_kwargs
=
pyobj_map
(
args_kwargs
,
fn
=
lambda
x
:
x
.
to
(
self
.
device
).
detach
(),
process_types
=
torch
.
Tensor
)
# torch rpc doesn't support args or rets in GPU
args_kwargs
=
pyobj_map
(
args_kwargs
,
fn
=
lambda
x
:
self
.
device
,
process_types
=
torch
.
device
)
# change devices from last stage to current device
args
,
kwargs
=
data_process_func
(
args_kwargs
)
...
...
@@ -459,6 +855,9 @@ class WorkerBase(ABC):
stage_input_kwargs
,
stage_outputs
,
checkpoint
=
use_checkpoint
)
consume_result
=
pyobj_map
(
consume_result
,
fn
=
lambda
x
:
x
.
to
(
'cpu'
),
process_types
=
torch
.
Tensor
)
# torch rpc doesn't support args or rets in
# if not forward_only, do the backward
if
not
forward_only
:
if
is_last_stage
:
# if it is the last stage, trigger backward automatic
...
...
@@ -486,21 +885,43 @@ class WorkerBase(ABC):
# overlap recompute and future.wait
if
not
is_last_stage
:
grad_tensors
=
get_real_args_kwargs
(
args
)
grad_tensors
=
self
.
_
get_real_args_kwargs
_bwd
(
args
)
else
:
grad_tensors
=
None
# take tensor only (for only tensor can do backward)
stage_outputs
=
pytree_filter
(
lambda
x
:
x
.
requires_grad
,
stage_outputs
,
process_types
=
torch
.
Tensor
)
grad_tensors
=
pytree_filter
(
lambda
x
:
x
is
not
None
,
grad_tensors
,
process_types
=
torch
.
Tensor
)
# TODO(jiangziyue) : All values which should do bp are torch.Tensor?
stage_outputs
=
pytree_filter
(
lambda
x
:
True
,
stage_outputs
,
process_types
=
torch
.
Tensor
)
grad_tensors
=
pytree_filter
(
lambda
x
:
True
,
grad_tensors
,
process_types
=
torch
.
Tensor
)
# output all input's grad to producer, even it has no grad(output None)
# to make the offset aligned to the topo's record.
if
grad_tensors
is
not
None
:
filtered_outputs
=
[]
filtered_grads
=
[]
for
i
,
grad
in
enumerate
(
grad_tensors
):
stage_output
=
stage_outputs
[
i
]
if
stage_output
.
requires_grad
and
grad
is
not
None
:
filtered_outputs
.
append
(
stage_output
)
filtered_grads
.
append
(
grad
)
stage_outputs
=
filtered_outputs
grad_tensors
=
pyobj_map
(
filtered_grads
,
fn
=
lambda
x
:
x
.
to
(
self
.
device
),
process_types
=
torch
.
Tensor
)
# torch rpc doesn't support args or rets in GPU
autograd
.
backward
(
stage_outputs
,
grad_tensors
=
grad_tensors
)
# collect grad of input tensor
consume_result
=
[]
if
not
is_first_stage
:
pytree_map
(
stage_input_args
,
lambda
x
:
consume_result
.
append
(
x
.
grad
),
process_types
=
torch
.
Tensor
)
pytree_map
(
stage_input_kwargs
,
lambda
x
:
consume_result
.
append
(
x
.
grad
),
process_types
=
torch
.
Tensor
)
# In current design, input mush be a flatten args.
for
arg
in
stage_input_args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
consume_result
.
append
(
arg
.
grad
)
else
:
consume_result
.
append
(
None
)
consume_result
=
pyobj_map
(
consume_result
,
fn
=
lambda
x
:
x
.
to
(
'cpu'
),
process_types
=
torch
.
Tensor
)
# torch rpc doesn't support args or rets in GPU
else
:
raise
TypeError
(
f
"Unknown phase appears in _consume_work_item_by_phase
{
phase
}
"
)
...
...
@@ -532,11 +953,11 @@ class WorkerBase(ABC):
def
_hook_before_step
(
self
):
pass
def
_reset_context
(
self
):
self
.
forward_times
=
0
self
.
backward_times
=
0
self
.
outstanding
=
0
self
.
_initialize_outstanding_range
()
# install the main loop to wait for next batch input
def
_wait_for_reset
(
self
):
with
self
.
reset_condition
:
self
.
reset_condition
.
wait_for
(
lambda
:
self
.
reset
)
self
.
reset
=
False
# do the main loop to consume ready_list
def
_work_loop
(
self
):
...
...
@@ -547,10 +968,10 @@ class WorkerBase(ABC):
# main loop
while
True
:
work_item_key
=
self
.
_get_work_item_key
()
# move current work item to output_list to activate subscribe in advance
with
self
.
work_list_condition_lock
:
work_item
=
self
.
work_list
.
pop
(
work_item_key
)
self
.
work_list_condition_lock
.
wait_for
(
lambda
:
work_item_key
in
self
.
work_list
)
work_item
=
self
.
work_list
[
work_item_key
]
with
self
.
output_list_condition_lock
:
# assert work_item_key not in self.output_list
...
...
@@ -559,27 +980,37 @@ class WorkerBase(ABC):
consume_result
=
self
.
_consume_work_item_by_phase
(
work_item
)
with
self
.
work_list_condition_lock
:
self
.
work_list
.
pop
(
work_item_key
)
work_item
.
output
.
set_result
(
consume_result
)
# if is last step in one batch reset context and do step
if
self
.
_is_last_step
(
work_item
):
self
.
_hook_before_step
()
if
hasattr
(
self
,
'optimizer'
)
and
not
work_item
.
forward_only
:
self
.
step
()
self
.
_reset_context
()
self
.
_wait_for_reset
()
# reset context and resume loop
def
reset_context
(
self
):
self
.
forward_times
=
0
self
.
backward_times
=
0
self
.
outstanding
=
0
self
.
_initialize_outstanding_range
()
with
self
.
work_list_condition_lock
:
self
.
work_list
.
clear
()
with
self
.
output_list_condition_lock
:
self
.
output_list
.
clear
()
with
self
.
reset_condition
:
self
.
reset
=
True
self
.
reset_condition
.
notify_all
()
def
initialize_optimizer
(
self
,
optimizer_class
:
type
,
**
kwargs
):
self
.
optimizer
:
optim
.
Optimizer
=
optimizer_class
(
self
.
module_partition
.
parameters
(),
**
kwargs
)
self
.
step_lock
=
threading
.
Lock
()
self
.
step_lock
.
acquire
()
def
wait_for_step
(
self
):
self
.
step_lock
.
acquire
()
def
step
(
self
):
self
.
_hook_before_step
()
self
.
optimizer
.
step
()
self
.
optimizer
.
zero_grad
()
self
.
step_lock
.
release
()
class PipelineEngineBase(ABC, nn.Module):
...
@@ -611,8 +1042,6 @@ class PipelineEngineBase(ABC, nn.Module):
        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
-        self.step_futs: List[Future] = []

        self._check_argument()
        self._create_pp_rank_to_rpc_worker_id()
        self._create_pp_rank_to_module_partition_id()
...
@@ -639,7 +1068,7 @@ class PipelineEngineBase(ABC, nn.Module):
    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        """create a map from model partition to stage_id, which is useful when use_interleave is True.
        e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
        pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
        of partitions will be moved to device 0 and the others to device 1
        """
...
@@ -692,6 +1121,15 @@ class PipelineEngineBase(ABC, nn.Module):
        for fut in sync_futs:
            fut.wait()

+    def remote_numels(self) -> Dict[int, int]:
+        numels = {}
+        actual_stage_num = self._get_actual_stage_num()
+        for stage_id in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[stage_id]
+            numel = worker_rref.rpc_sync().get_numels()
+            numels[stage_id] = numel
+        return numels
+
    def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
        parameters = {}
        actual_stage_num = self._get_actual_stage_num()
...
@@ -728,9 +1166,14 @@ class PipelineEngineBase(ABC, nn.Module):
                ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
        else:
            key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
+            futs = []
            for pp_rank in input_pp_ranks:
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().get_output_by_key(key)
+                fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
+                futs.append(fut)
+
+            for fut in futs:
+                fut.wait()

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        num_microbatches = self.num_microbatches
...
@@ -748,6 +1191,7 @@ class PipelineEngineBase(ABC, nn.Module):
            # TODO : add relationship between output_pp_ranks and parts of microlabels
            worker_rref.remote().set_labels(microbatch_id, microlabels)

+    # TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        for pp_rank in output_pp_ranks:
...
@@ -756,10 +1200,16 @@ class PipelineEngineBase(ABC, nn.Module):
    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
        if not forward_only:
+            backward_result = []
            for pp_rank in input_pp_ranks:
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-                worker_rref.rpc_sync().get_output_by_key(key)
+                fut = worker_rref.rpc_async().get_output_by_key(key, offsets=[])    # only ensure the res exists, no need for real data.
+                backward_result.append(fut)
+
+            for fut in backward_result:
+                fut.wait()

    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        forward_result = []
...
@@ -776,6 +1226,17 @@ class PipelineEngineBase(ABC, nn.Module):

        return forward_result

+    def _reset_worker(self):
+        actual_stage_num = self._get_actual_stage_num()
+        reset_futs: List[Future] = []
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().reset_context()
+            reset_futs.append(fut)
+
+        for fut in reset_futs:
+            fut.wait()
+
    def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
        batch_lengths = get_batch_lengths(batch)
        batch_length = batch_lengths[0]
...
@@ -800,7 +1261,7 @@ class PipelineEngineBase(ABC, nn.Module):
        for microbatch_id in range(num_microbatches):
            # control data input speed
            # to prevent exceed of wait limitations
-            self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
+            # self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
            batch_start = microbatch_size * microbatch_id
            batch_end = min(batch_start + microbatch_size, batch_length)
...
@@ -824,11 +1285,9 @@ class PipelineEngineBase(ABC, nn.Module):
        forward_result = self._collect_forward_result(output_pp_ranks, ret_future)

        if not forward_only and hasattr(self, 'optimizer_class'):
-            # wait for all step
-            for pp_rank in self.pp_rank_to_worker_rref:
-                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().wait_for_step()
+            self.step()

+        self._reset_worker()    # reset worker attributes for next batch
        return forward_result

    def initialize_optimizer(self, optimizer_class: type, **kwargs):
...
@@ -839,10 +1298,11 @@ class PipelineEngineBase(ABC, nn.Module):
    def step(self):
        actual_stage_num = self._get_actual_stage_num()
+        step_futs: List[Future] = []
        for pp_rank in range(actual_stage_num):
            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
            fut = worker_rref.rpc_async().step()
-            self.step_futs.append(fut)
+            step_futs.append(fut)

-        for fut in self.step_futs:
+        for fut in step_futs:
            fut.wait()
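The engine-side `step()` above only dispatches asynchronous `step()` calls to every worker and waits on the returned futures. A minimal driver loop, assuming an already constructed pipeline engine that exposes the `initialize_optimizer` and `forward_backward` methods shown in this diff (the model partitioning, data loader and RPC setup are placeholders, not part of this file), might look like this sketch:

    import torch

    # hypothetical driver; `engine` stands for a concrete PipelineEngineBase subclass
    # that was built elsewhere from an already partitioned model
    def train_one_epoch(engine, dataloader, lr=1e-3):
        # each worker instantiates the optimizer class on its own module partition
        engine.initialize_optimizer(torch.optim.Adam, lr=lr)
        for batch, labels in dataloader:
            # forward_backward schedules all microbatches and, as shown above,
            # calls engine.step() before resetting the workers for the next batch
            forward_result = engine.forward_backward(batch, labels, forward_only=False)
        return forward_result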
colossalai/pipeline/rpc/_pipeline_schedule.py
@@ -3,11 +3,12 @@ from typing import Callable, Dict, List

import torch
import torch.distributed as dist
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future

+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
+
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the class for use
...
@@ -86,12 +87,9 @@ class OneFOneBWorker(WorkerBase):
                outstanding_min = actual_stage_num - pp_rank - 1
                outstanding_max = actual_stage_num - pp_rank
                self.outstanding_range = (outstanding_min, outstanding_max)
-            elif target_key.microbatch_id == num_microbatches - 1:
+            if target_key.microbatch_id == num_microbatches - 1:
                self.outstanding_range = (0, 0)

        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)

        return target_key
...
colossalai/pipeline/rpc/utils.py
@@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
-from colossalai.initialize import launch
-from colossalai.pipeline.pipeline_process_group import ppg
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future

+from colossalai.initialize import launch
+from colossalai.pipeline.pipeline_process_group import ppg
+
+
+def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
+    if isinstance(obj, process_types):
+        return fn(obj)
+    elif type(obj) is dict:
+        return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
+    elif type(obj) is tuple:
+        return tuple(pyobj_map(o, fn, process_types) for o in obj)
+    elif type(obj) is list:
+        return list(pyobj_map(o, fn, process_types) for o in obj)
+    else:
+        return obj
+

def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
    """process object recursively, like pytree
...
@@ -19,10 +33,10 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
        obj (:class:`Any`): object to process
        fn (:class:`Callable`): a function to process subobject in obj
        process_types (:class: `type | tuple[type]`): types to determine the type to process
        map_all (:class: `bool`): if map_all is True, then any type of element will use fn

    Returns:
        :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
    """
    if isinstance(obj, dict):
        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
...
@@ -137,5 +151,5 @@ def parse_args():
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--master_addr', type=str, default='localhost')
    parser.add_argument('--master_port', type=str, default='29020')
-    parser.add_argument('--num_worker_threads', type=str, default=128)
+    parser.add_argument('--num_worker_threads', type=int, default=128)
    return parser.parse_args()
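`pyobj_map` applies `fn` only to leaves whose type matches `process_types` and rebuilds dicts, tuples and lists around them. A small self-contained sketch of the same recursion (renamed so it is not mistaken for the library function) illustrates the behaviour:

    from typing import Any, Callable, Tuple, Type, Union

    def _map_leaves(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
        # same structure-preserving recursion as pyobj_map above
        if isinstance(obj, process_types):
            return fn(obj)
        elif type(obj) is dict:
            return {k: _map_leaves(v, fn, process_types) for k, v in obj.items()}
        elif type(obj) is tuple:
            return tuple(_map_leaves(o, fn, process_types) for o in obj)
        elif type(obj) is list:
            return [_map_leaves(o, fn, process_types) for o in obj]
        return obj

    nested = {'a': (1, 2), 'b': [3, {'c': 4}], 'd': 'keep me'}
    # doubles every int leaf, leaves the string untouched
    print(_map_leaves(nested, lambda x: x * 2, process_types=int))
    # {'a': (2, 4), 'b': [6, {'c': 8}], 'd': 'keep me'}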
colossalai/pipeline/utils.py
@@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule
from typing import List
+from collections import OrderedDict


def _binary_partition(weights: List, start: int, end: int):
    """Returns the binary partition position of `weights`, given the start
...
@@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
        kwargs_offset = 0
    elif isinstance(input_tensor, torch.Tensor):
        kwargs_offset = 1
-    else:
-        assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
+    elif isinstance(input_tensor, (tuple, OrderedDict)):
+        #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
+        # Huggingface will take their own structures based on OrderedDict as the output
+        # between layers so we've to close this check.
        kwargs_offset = len(input_tensor)
    args_name_list = list(sig.parameters.keys())
    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
...
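The `kwargs_offset` computed above is simply the number of leading positional parameters already covered by `input_tensor`; the remaining parameter names are then used to filter `kw_dict`. A standalone sketch of that filtering logic, using a toy function instead of a real pipeline module, could be:

    import inspect
    from collections import OrderedDict

    import torch

    def filter_kwargs(function, input_tensor, kw_dict):
        # mirrors the offset logic above: how many positional slots does input_tensor already fill?
        sig = inspect.signature(function)
        if input_tensor is None:
            kwargs_offset = 0
        elif isinstance(input_tensor, torch.Tensor):
            kwargs_offset = 1
        elif isinstance(input_tensor, (tuple, OrderedDict)):
            # e.g. Huggingface layers may hand an OrderedDict-based output to the next stage
            kwargs_offset = len(input_tensor)
        args_name_list = list(sig.parameters.keys())
        return {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}

    def toy_layer(hidden, attention_mask=None, head_mask=None):
        return hidden

    # two positional inputs are already provided by the previous stage,
    # so only 'head_mask' survives the filtering
    print(filter_kwargs(toy_layer, (torch.ones(2), torch.zeros(2)), {'hidden': 1, 'head_mask': 2}))
    # {'head_mask': 2}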
colossalai/tensor/__init__.py
-from .process_group import ProcessGroup
-from .tensor_spec import ColoTensorSpec
-from .distspec import ShardSpec
-from .distspec import ReplicaSpec
-from .compute_spec import ComputeSpec, ComputePattern
-from .colo_tensor import ColoTensor
-from . import distspec
from .colo_parameter import ColoParameter
-from .utils import convert_parameter, named_params_with_colotensor
-from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .colo_tensor import ColoTensor
+from .comm_spec import CollectiveCommPattern, CommSpec
+from . import distspec
+from .compute_spec import ComputePattern, ComputeSpec
+from .dist_spec_mgr import DistSpecManager
+from .distspec import ReplicaSpec, ShardSpec
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
+from .process_group import ProcessGroup
+from .tensor_spec import ColoTensorSpec
+from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

__all__ = [
    'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
+    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'merge_same_dim_mesh_list'
]
colossalai/tensor/colo_parameter.py
-import torch
from typing import Optional

+import torch
+
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
-from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.tensor.tensor_spec import ColoTensorSpec
+
+
+def filter_colo_parameters(*args, **kwargs):
+    param_list = []
+
+    def get_colo_parameters(element) -> None:
+        if isinstance(element, list) or isinstance(element, tuple):
+            for e in element:
+                get_colo_parameters(e)
+        elif isinstance(element, dict):
+            raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
+        elif isinstance(element, ColoParameter):
+            param_list.append(element)
+        return
+
+    for a in args:
+        get_colo_parameters(a)
+    for v in kwargs.values():
+        get_colo_parameters(v)

-def filter_args(func, *args):
-    return [arg for arg in args if func(arg)]
+    return param_list


def replace_args(args, kwargs, new_args):
...
@@ -58,18 +75,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

    @classmethod
    def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
            if not func.__name__.startswith('__'):
                if kwargs is None:
                    kwargs = {}
-                params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
+                params = filter_colo_parameters(*args, **kwargs)
                if len(params) > 0:
                    with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                    args, kwargs = replace_args(args, kwargs, new_args)
                    ret = super().__torch_function__(func, types, args, kwargs)
                    with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                    return ret
        return super().__torch_function__(func, types, args, kwargs)
...
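The interception above relies on the standard `__torch_function__` protocol: every torch op whose arguments contain a `ColoParameter` first passes through the hook manager. The following standalone toy (not the library code; `TracedParameter` and its attributes are made-up names) shows the same idea with a plain `torch.nn.Parameter` subclass that only records which ops saw one of its instances:

    import torch

    class TracedParameter(torch.nn.Parameter):
        # toy stand-in for ColoParameter: records which ops saw one of these parameters
        calls = []

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            if not func.__name__.startswith('__'):
                # same filtering idea as filter_colo_parameters above, flattened for simplicity
                params = [a for a in list(args) + list(kwargs.values()) if isinstance(a, TracedParameter)]
                if params:
                    TracedParameter.calls.append((func.__name__, len(params)))
            return super().__torch_function__(func, types, args, kwargs)

    w = TracedParameter(torch.randn(4, 3))
    out = torch.matmul(torch.randn(2, 4), w)
    print(TracedParameter.calls)    # e.g. [('matmul', 1)]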
colossalai/tensor/colo_tensor.py
-from .op_wrapper import _COLOSSAL_OPS
-from .const import TensorType
from copy import copy
-import torch
from functools import lru_cache
+from typing import Callable, Optional, Set
+
+import torch

-from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional, Set, Callable
+from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.tensor.tensor_spec import ColoTensorSpec
+
+from .const import TensorType
+from .op_wrapper import _COLOSSAL_OPS


@lru_cache(None)
...
@@ -55,27 +57,29 @@ class ColoTensor(torch.Tensor):
    The Colotensor can be initialized with a PyTorch tensor in the following ways.

        >>> pg = ProcessGroup()
-        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
+        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
        >>> # The tensor passed in is a tensor after sharding but not a global tensor.
        >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
        >>>                        dims=[0],
        >>>                        num_partitions=[world_size])
        >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
        >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

    Args:
        data (torch.Tensor): a torch tensor used as the payload the colotensor.
        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
    """
+    torch_major = int(torch.__version__.split('.')[0])
+    torch_minor = int(torch.__version__.split('.')[1])

    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
        """
        The signature of the __new__ has to be consistent with the torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload the colotensor.
            spec (TensorSpec, optional): the tensor spec of initialization.

        Returns:
            ColoTensor: a ColoTensor wrappers the data.
        """
...
@@ -100,7 +104,6 @@ class ColoTensor(torch.Tensor):
            self.process_group = spec.pg

        self._type = TensorType.NONMODEL
-        self._graph_node = None

    def has_compute_spec(self) -> bool:
        return self.compute_spec is not None
...
@@ -112,9 +115,9 @@ class ColoTensor(torch.Tensor):
        return self.process_group

    def set_process_group(self, pg: ProcessGroup):
        """set_process_group
        change the pg of the ColoTensor. Note that the valid use cases is limited.
-        Only existing pg is DP and dist spec is REPLICaTE is valid.
+        It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.

        Args:
            pg (ProcessGroup): target pg
...
@@ -124,10 +127,10 @@ class ColoTensor(torch.Tensor):
        # if the new pg is the same as the old pg, just returns
        if self.process_group == pg:
            return
-        assert self.process_group.tp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
        assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
+            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"

        self.process_group = pg
...
@@ -135,7 +138,7 @@ class ColoTensor(torch.Tensor):
        return self.process_group.tp_world_size()

    def set_dist_spec(self, dist_spec: _DistSpec):
        """set_dist_spec
        set dist spec and change the payloads.

        Args:
...
@@ -166,6 +169,16 @@ class ColoTensor(torch.Tensor):
        if func in _COLOSSAL_OPS:
            func = _COLOSSAL_OPS[func]

+        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
+            # in order to trigger pre-op hook in the forward of checkpoint module
+            # we have to capture the `backward` function
+            # and make sure that it does not in `torch._C.DisableTorchFunction()` context
+            if func is torch.Tensor.backward:
+                assert len(args) == 1    # only has 1 paramter
+                backward_tensor = torch.Tensor(args[0])
+                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
+                return backward_tensor.backward(**tensor_kwargs)
+
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in _get_my_nowrap_functions():
...
@@ -178,7 +191,7 @@ class ColoTensor(torch.Tensor):
        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'

    def _redistribute(self, dist_spec: _DistSpec) -> None:
        """_redistribute
        Note the function will not handle the logic of backward propagation!
        It is used during model tensor initializations as an internal function.
...
@@ -191,12 +204,12 @@ class ColoTensor(torch.Tensor):
        self.dist_spec = dist_spec

    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
        """redistribute
        Redistribute the tensor among processes. The rule is like this:

        1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
        DP process group not changed.

        2. If the pg is not not None and not equal to the current process group.
        First, convert the tensor as replicated among the TP process group.
        Second, reset the process group to the new pg.
...
@@ -220,7 +233,7 @@ class ColoTensor(torch.Tensor):
        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))

    def to_replicate_(self):
        """to_replicate_
        an inline member function, converting dist spec of the tensor to REPLICATE
        """
...
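The `torch_major`/`torch_minor` class attributes added above gate the special handling of `torch.Tensor.backward` to PyTorch 1.12 and later. A tiny self-contained sketch of that version check in isolation (the helper name is made up for illustration):

    import torch

    torch_major, torch_minor = (int(x) for x in torch.__version__.split('.')[:2])

    def needs_backward_capture() -> bool:
        # mirrors the gate used in ColoTensor.__torch_function__ above
        return torch_major > 1 or (torch_major == 1 and torch_minor >= 12)

    print(torch.__version__, needs_backward_capture())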
colossalai/tensor/comm_spec.py
...
@@ -23,9 +23,9 @@ def _all_gather(tensor, comm_spec):
        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
        for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
    ]
-    tensor = tensor
-    group = process_group
-    dist.all_gather(tensor_list, tensor, group=group)
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
    return output
...
@@ -37,11 +37,10 @@ def _split(tensor, comm_spec):
    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, _ in process_groups_list:
        if dist.get_rank() in rank_list:
-            tensor = tensor
            dim = comm_spec.shard_dim
            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length)
+            output = torch.narrow(tensor, dim, start, length).contiguous()
            return output
...
@@ -69,17 +68,145 @@ def _all_to_all(tensor, comm_spec):
    return output


-def _all_reduce(tensor, comm_spec):
+def _all_reduce(tensor, comm_spec, async_op=False):
    '''
    Implement all reduce operation on device mesh based on information provided by comm_spec.
    '''
    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
+            if not tensor.is_contiguous():
+                tensor = tensor.contiguous()
+            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
            return tensor


+def _mix_gather(tensor, comm_spec):
+    '''
+    Implement mix gather operation on device mesh based on information provided by comm_spec.
+    Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is
+    different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather
+    only does all-gather in one dimension.
+    Assume index of f and b target pairs are 'f' and 'b'
+    ShardingSpec => gather_dim, logical_process_axes
+    S0S1 => [b, f], (1, 0)
+    S1S0 => [b, f], (0, 1)
+    S01R => [f], (1, 1)
+    RS01 => [b], (1, 1)
+    Example:
+        mesh_shape = (2,4)
+        # [[0, 1, 2, 3],
+        #  [4, 5, 6, 7]]
+        # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
+    S0S1:
+        leading_group_dim = 1
+        process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
+        tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...]
+        mesh_shape = (2,4)
+        cat_slice = [4,2]
+        tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)]
+        tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b)
+        tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b)
+        output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a)
+    S1S0:
+        leading_group_dim = 0
+        process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
+        tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)]
+        mesh_shape = (2,4)
+        cat_slice = [2,4]
+        tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)]
+        tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b)
+        tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b)
+        tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b)
+        tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b)
+    S10R:
+        leading_group_dim = 0
+        process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
+        tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
+    S01R:
+        leading_group_dim = 1
+        process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
+        tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
+    '''
+    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
+    leading_group_dim = comm_spec.logical_process_axes[0]
+    assert len(comm_spec.device_mesh.process_groups_dict) == 1
+    _, process_group = comm_spec.device_mesh.process_groups_dict[0][0]
+    process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
+
+    # Global all_gather
+    dist.all_gather(tensor_list, tensor, group=process_group)
+
+    # This is very ugly. I'm figuring out more elegant methods
+    tensor_list_sorted = [
+        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)
+    ]
+    for i in range(total_slices):
+        tensor_list_sorted[i] = tensor_list[process_number_list[i]]
+    tensor_list = tensor_list_sorted
+
+    if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
+        output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
+    else:
+        mesh_shape = comm_spec.device_meshes.mesh_shape
+        cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
+        tmp_tensor_shape = list(tensor.shape)
+        tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
+        tmp_tensor_shape = torch.Size(tmp_tensor_shape)
+        tmp_tensor_list = [
+            torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1])
+        ]
+        for i in range(cat_slice[1]):
+            tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]),
+                                           comm_spec.gather_dim[0]).contiguous()
+        output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous()
+
+    return output
+
+
+def _mix_split(tensor, comm_spec):
+    '''
+    Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent)
+    Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split
+    because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension.
+    Assume index of f and b target pairs are 'f' and 'b'
+    S0S1 => [b, f], (1, 0)
+    S1S0 => [b, f], (0, 1)
+    S01R => [f], (0, 0)
+    RS01 => [b], (0, 0)
+    Example:
+        mesh_shape = (2,4)
+        # [[0, 1, 2, 3],
+        #  [4, 5, 6, 7]]
+        # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
+    '''
+    mesh_shape = comm_spec.device_meshes.mesh_shape
+    dim = comm_spec.gather_dim
+    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    # Get global rank
+    rank = dist.get_rank()
+    leading_group_dim = comm_spec.logical_process_axes[0]
+    process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
+    rank = process_number_list.index(rank)
+
+    if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
+        length = tensor.shape[dim[0]] // total_slices
+        start = length * rank
+        output = torch.narrow(tensor, dim[0], start, length).contiguous()
+    else:
+        tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]]
+        rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
+        length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]]
+        start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
+        tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous()
+        output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous()
+
+    return output


class _ReduceGrad(torch.autograd.Function):
    """
    A customized communication operation which forward is an identity operation,
...
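The two-stage concatenation that `_mix_gather` performs can be followed with plain tensors, without the DeviceMesh machinery. The sketch below assumes a global (4, 8) tensor sharded S0S1 over a (2, 4) device mesh (mesh axis 0 shards dim 0, mesh axis 1 shards dim 1), matching the `cat_slice` example in the docstring above; the shard layout and names are illustrative only:

    import torch

    full = torch.arange(32).reshape(4, 8)
    # slice (f, b) has shape (2, 2); slices that share the same f index are adjacent in the list
    slices = [full[2 * f:2 * (f + 1), 2 * b:2 * (b + 1)] for f in range(2) for b in range(4)]

    # first cat the 4 slices that share the same f index along the dim-1 shard ...
    rows = [torch.cat(slices[i * 4:(i + 1) * 4], dim=1) for i in range(2)]
    # ... then cat the two row blocks along dim 0, recovering the global tensor
    rebuilt = torch.cat(rows, dim=0)
    assert torch.equal(rebuilt, full)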
@@ -205,6 +332,22 @@ class _AllToAll(torch.autograd.Function):
        return _all_to_all(grad_outputs, ctx.comm_spec), None


+class _MixGatherForwardMixSplitBackward(torch.autograd.Function):
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _mix_gather(input_)
+
+    @staticmethod
+    def forward(ctx, input_, comm_spec):
+        ctx.comm_spec = comm_spec
+        return _mix_gather(input_, comm_spec)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _mix_split(grad_output, ctx.comm_spec), None
+
+
def reduce_grad(input_, comm_spec):
    return _ReduceGrad.apply(input_, comm_spec)
...
@@ -225,12 +368,17 @@ def all_to_all(input_, comm_spec):
    return _AllToAll.apply(input_, comm_spec)


+def mixgather_forward_split_backward(input_, comm_spec):
+    return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec)
+
+
class CollectiveCommPattern(Enum):
    GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
    ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
    SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
    ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
    IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
+    MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"


class CommSpec:
...
@@ -256,7 +404,8 @@ class CommSpec:
                 gather_dim=None,
                 shard_dim=None,
                 logical_process_axis=None,
-                 forward_only=False):
+                 forward_only=False,
+                 mix_gather=False):
        self.comm_pattern = comm_pattern
        self.sharding_spec = sharding_spec
        self.gather_dim = gather_dim
...
@@ -264,8 +413,14 @@ class CommSpec:
        self.logical_process_axis = logical_process_axis
        self.forward_only = forward_only
        if isinstance(self.logical_process_axis, list):
-            self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
-            self.logical_process_axis = 0
+            if not mix_gather:
+                self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+                self.logical_process_axis = 0
+            else:
+                self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
+                self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+                # Create a new member `logical_process_axes` to distinguish from original flatten
+                self.logical_process_axes = logical_process_axis
        else:
            self.device_mesh = self.sharding_spec.device_mesh
...
@@ -290,6 +445,10 @@ class CommSpec:
        elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
            res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+        elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
+            res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ")
+            res_list.append(f"gather_dim:{self.gather_dim}, ")
+            res_list.append(f"logical_process_asex:{self.logical_process_axes})")

        return ''.join(res_list)
...
@@ -325,6 +484,11 @@ class CommSpec:
            forward_communication_cost = 10
            backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)

+        if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
+            # no need for axis because all devices are used in mix_gather
+            forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
+            backward_communication_cost = 10
+
        if self.forward_only:
            cost_dict["forward"] = forward_communication_cost
            cost_dict["backward"] = 0
...
@@ -357,4 +521,5 @@ pattern_to_func_dict = {
    CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
    CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
    CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
+    CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward,
}
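Every comm pattern in this file pairs an asymmetric forward and backward (gather forward / split backward, identity forward / all-reduce backward, and now mix-gather forward / mix-split backward). The single-process toy below illustrates that structure with a plain `torch.autograd.Function`; it is a sketch, not the distributed implementation, and it fakes the all-gather by concatenating identical chunks:

    import torch

    class GatherForwardSplitBackward(torch.autograd.Function):
        # forward "gathers" the chunk, backward returns only this rank's slice of the gradient

        @staticmethod
        def forward(ctx, chunk, world_size, rank, dim):
            ctx.world_size, ctx.rank, ctx.dim = world_size, rank, dim
            # stand-in for dist.all_gather: here every "rank" holds an identical chunk
            return torch.cat([chunk] * world_size, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            grads = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
            return grads[ctx.rank].contiguous(), None, None, None

    x = torch.randn(2, 3, requires_grad=True)
    y = GatherForwardSplitBackward.apply(x, 4, 1, 0)    # shape (8, 3)
    y.sum().backward()
    print(x.grad.shape)    # torch.Size([2, 3])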
colossalai/tensor/dist_spec_mgr.py
-from colossalai.tensor.distspec import _DistSpec
-# from colossalai.nn.layer.utils import divide
-from numpy import prod
from contextlib import contextmanager

import torch
import torch.distributed as dist
+# from colossalai.nn.layer.utils import divide
+from numpy import prod
from packaging import version

from colossalai.logging import get_dist_logger
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.process_group import ProcessGroup

# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
...
colossalai/tensor/param_op_hook.py
-import torch
-from contextlib import contextmanager
from abc import ABC, abstractmethod
-from typing import List, Tuple, Any
+from contextlib import contextmanager
+from typing import Any, List, Tuple
+
+import torch

from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor import ColoTensorSpec
+from colossalai.tensor.tensor_spec import ColoTensorSpec


-class ParamOpHook(ABC):
-    """Hook which is triggered by each operation when operands contain ColoParameter.
+class ColoParamOpHook(ABC):
+    """
+    Hook which is triggered by each operation when operands contain ColoParameter.
    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
-    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
-    of ColoParameter.
+    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods apply a list
+    of ColoParameter as input args.
    """

    @abstractmethod
...
@@ -30,68 +33,79 @@ class ParamOpHook(ABC):
        pass


-class ParamOpHookManager:
-    """Manage your param op hooks. It only has static methods.
+class ColoParamOpHookManager:
+    """
+    Manage your param op hooks. It only has static methods.
    The only static method you should call is ``use_hooks(*hooks)``.
    """
-    hooks: Tuple[ParamOpHook, ...] = tuple()
+    hooks: Tuple[ColoParamOpHook, ...] = tuple()

    @staticmethod
    @contextmanager
-    def use_hooks(*hooks: ParamOpHook):
+    def use_hooks(*hooks: ColoParamOpHook):
        """Change the param op hooks you use. Nested calling is allowed.

        Example:
-            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>> with ColoParamOpHookManager.use_hooks(*hooks):
            >>>     do_something()
-            >>>     with ParamOpHookManager.use_hooks():
+            >>>     with ColoParamOpHookManager.use_hooks():
            >>>         // clear hooks
            >>>         do_something()
        """
        try:
-            old_param_op_hooks = ParamOpHookManager.hooks
-            ParamOpHookManager.hooks = hooks
+            old_param_op_hooks = ColoParamOpHookManager.hooks
+            ColoParamOpHookManager.hooks = hooks
            yield
        finally:
-            ParamOpHookManager.hooks = old_param_op_hooks
+            ColoParamOpHookManager.hooks = old_param_op_hooks

    @staticmethod
    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
            hook.pre_forward(params)

    @staticmethod
    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
            hook.post_forward(params)

    @staticmethod
    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
            hook.pre_backward(params)

    @staticmethod
    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
            hook.post_backward(params)

    @staticmethod
    def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-        ParamOpHookManager._trigger_pre_forward(params)
-        args_info = _get_colo_tensors_info(*args)
-        rets = PreFwdPostBwd.apply(params, *args)
-        return _update_colo_tensors(args_info, *rets)
+        ColoParamOpHookManager._trigger_pre_forward(params)
+        grad_args, rear_args = _get_grad_args(*args)
+        colo_info = _get_colo_tensors_info(*grad_args)
+        rets = PreFwdPostBwd.apply(params, *grad_args)
+        update_args = _update_colo_tensors(colo_info, *rets)
+        if rear_args is None:
+            return update_args
+        else:
+            arg_zero = (tuple(update_args),)
+            return arg_zero + rear_args

    @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-        ParamOpHookManager._trigger_post_forward(params)
-        arg_info = _get_colo_tensors_info(arg)
+        ColoParamOpHookManager._trigger_post_forward(params)
+        colo_info = _get_colo_tensors_info(arg)
        ret = PostFwdPreBwd.apply(params, arg)
-        return _unpack_args(_update_colo_tensors(arg_info, ret))
+        res = _update_colo_tensors(colo_info, ret)
+        if len(res) == 1:
+            return res[0]
+        else:
+            return res

    @staticmethod
    def has_hook() -> bool:
-        return len(ParamOpHookManager.hooks) > 0
+        return len(ColoParamOpHookManager.hooks) > 0


class PreFwdPostBwd(torch.autograd.Function):
...
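A custom hook plugs into this manager by subclassing `ColoParamOpHook` and implementing the four methods the triggers above call. The sketch below is written against the interface shown in this hunk (the hook class name and the training loop it would wrap are made up for illustration):

    from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager

    class CountingHook(ColoParamOpHook):
        # counts how many ops touched ColoParameters during forward and backward

        def __init__(self):
            super().__init__()
            self.fwd_ops = 0
            self.bwd_ops = 0

        def pre_forward(self, params) -> None:
            self.fwd_ops += 1

        def post_forward(self, params) -> None:
            pass

        def pre_backward(self, params) -> None:
            self.bwd_ops += 1

        def post_backward(self, params) -> None:
            pass

    hook = CountingHook()
    with ColoParamOpHookManager.use_hooks(hook):
        ...    # run a forward/backward pass over a model built from ColoParameters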
@@ -99,11 +113,11 @@ class PreFwdPostBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
-        return _unpack_args(args)
+        return args

    @staticmethod
    def backward(ctx, *grads):
-        ParamOpHookManager._trigger_post_backward(ctx.params)
+        ColoParamOpHookManager._trigger_post_backward(ctx.params)
        return (None,) + grads
...
@@ -116,14 +130,51 @@ class PostFwdPreBwd(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *grads):
-        ParamOpHookManager._trigger_pre_backward(ctx.params)
+        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
        return (None,) + grads


-def _unpack_args(args):
-    if len(args) == 1:
-        return args[0]
-    return args
+def _is_grad_tensor(obj) -> bool:
+    if torch.is_tensor(obj):
+        if obj.grad_fn is not None or obj.requires_grad:
+            return True
+    return False
+
+
+def _has_grad_tensor(obj) -> bool:
+    if isinstance(obj, tuple) or isinstance(obj, list):
+        for x in obj:
+            if _has_grad_tensor(x):
+                return True
+        return False
+    elif isinstance(obj, dict):
+        for x in obj.values():
+            if _has_grad_tensor(x):
+                return True
+        return False
+    else:
+        return _is_grad_tensor(obj)
+
+
+def _get_grad_args(*args):
+    # if there is no grad tensors, do nothing
+    if not _has_grad_tensor(args):
+        return args, None
+    # returns the identical args if there is a grad tensor
+    for obj in args:
+        if _is_grad_tensor(obj):
+            return args, None
+    # otherwise, the first arguement should be a tuple of grad tensors
+    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
+    arg_zero = args[0]
+    if not isinstance(arg_zero, tuple):
+        raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.")
+    check_grad_flag = False
+    for obj in arg_zero:
+        check_grad_flag |= _is_grad_tensor(obj)
+    if not check_grad_flag:
+        raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.")
+    return arg_zero, args[1:]


def _get_colo_tensors_info(*args) -> list:
...
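The new `_get_grad_args` helper only forwards arguments that can actually participate in autograd, because `PreFwdPostBwd.backward` never fires if none of its inputs carry gradients. The grad-tensor detection itself is easy to reproduce standalone; the helper names below are local to this sketch:

    import torch

    def is_grad_tensor(obj):
        return torch.is_tensor(obj) and (obj.grad_fn is not None or obj.requires_grad)

    def has_grad_tensor(obj):
        if isinstance(obj, (tuple, list)):
            return any(has_grad_tensor(x) for x in obj)
        if isinstance(obj, dict):
            return any(has_grad_tensor(x) for x in obj.values())
        return is_grad_tensor(obj)

    a = torch.randn(2)                       # plain tensor, no grad
    b = torch.randn(2, requires_grad=True)   # leaf with requires_grad
    print(has_grad_tensor((a, 3)))           # False -> pre_op would pass args through untouched
    print(has_grad_tensor({'x': [a, b]}))    # True  -> PreFwdPostBwd.apply can hook the backward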
colossalai/tensor/shape_consistency.py
+import math
+import operator
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import reduce
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple

+import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator

from .comm_spec import *
...
@@ -28,6 +25,15 @@ class ShapeConsistencyOptions:
    pass


+def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
+    shape_consistency_manager = ShapeConsistencyManager()
+    global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
+    with torch.no_grad():
+        global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
+                                                                                 global_sharding_spec)
+    return global_tensor
+
+
def set_shape_consistency_options(options: ShapeConsistencyOptions):
    """
    Configure the shape consistency manager via function call.
...
@@ -63,7 +69,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        assert isinstance(value, bool)
        self._forward_only = value

-    def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_gather_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
        '''
        Get all valid sharding specs from source_spec with single all-gather operation, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
...
@@ -71,7 +78,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        Argument:
            source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.

        Return:
            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
...
@@ -83,7 +90,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            # device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
            print(rst_dict)

        Output:
...
@@ -134,7 +141,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                pass
        return valid_spec_dict

-    def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
+    def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
        '''
        Get all valid sharding specs from source_spec with single all-to-all operation, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
...
@@ -142,7 +150,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        Argument:
            source_spec(ShardingSpec): the ShardingSpec of the source_spec.
-            orig_cost(float): the original communication cost before this operation.
+            orig_cost(Dict[str, float]): the original communication cost before this operation.

        Return:
            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
...
@@ -154,7 +162,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            # device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
            print(rst_dict)

        Output:
...
@@ -241,7 +249,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        return valid_spec_dict

-    def get_all_shard_spec(self, source_spec, orig_cost_dict):
+    def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
        '''
        Get all valid sharding specs from source_spec with single shard operation, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
...
@@ -261,7 +269,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            # device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            shape_consistency_manager = ShapeConsistencyManager()
-            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
+            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
            print(rst_dict)

        Output:
...
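Each `get_all_*_spec` helper now threads a per-phase cost dictionary instead of a single float; the accumulation itself is plain dictionary arithmetic, e.g. (numbers below are hypothetical):

    # hypothetical numbers: the cost of one more all-gather added onto an existing cost dict
    orig_cost_dict = {'forward': 0.0, 'backward': 0.0, 'total': 0.0}
    comm_cost = {'forward': 1.2, 'backward': 0.0, 'total': 1.2}

    accumulated = {phase: comm_cost[phase] + orig_cost_dict[phase] for phase in comm_cost}
    print(accumulated)    # {'forward': 1.2, 'backward': 0.0, 'total': 1.2}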
@@ -322,7 +330,60 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                pass
        return valid_spec_dict

-    def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
+    def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
+        '''
+        S0S1 -> RR
+        S1S0 -> RR
+        S01R -> RR
+        RS01 -> RR
+        '''
+        valid_spec_dict = {}
+        comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
+        tensor_dims = len(source_spec.entire_shape)
+        for f_index in range(tensor_dims - 1):
+            for b_index in range(f_index + 1, tensor_dims):
+                if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
+                    continue
+                else:
+                    if f_index in source_spec.dim_partition_dict:
+                        # skip (S10, R) -> (R, R)
+                        if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
+                            continue
+                        f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
+                    else:
+                        f_target_pair = (f_index, [])
+                    if b_index in source_spec.dim_partition_dict:
+                        # skip (R, S10) -> (R, R)
+                        if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
+                            continue
+                        b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
+                    else:
+                        b_target_pair = (b_index, [])
+
+                gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
+                comm_spec = CommSpec(comm_pathern,
+                                     sharding_spec=source_spec,
+                                     gather_dim=gather_dim,
+                                     logical_process_axis=logical_process_axes,
+                                     forward_only=self.forward_only,
+                                     mix_gather=True)
+                cost_dict = comm_spec.get_comm_cost()
+                new_dim_partition_dict = {}
+                # generate new sharding spec
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
+        return valid_spec_dict
+
+    def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
        '''
        Get all valid sharding specs from source_spec with one step transform, and
        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
...
@@ -344,7 +405,167 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
        return valid_spec_dict

-    def shape_consistency(self, source_spec, target_spec):
+    def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
+        """memory cost of the communication action sequence
+
+        Args:
+            comm_action_sequence (List[CommSpec]): list of communication actions
+
+        Returns:
+            TrainCycleItem: memory (numel) cost of such comm_action_sequence
+        """
+
+        def compute_shape(sharding_spec: ShardingSpec):
+            shape = sharding_spec.entire_shape
+            new_shape = []
+            for dim, shard in sharding_spec.dim_partition_dict.items():
+                new_shape.append(shape[dim] // len(shard))
+            return new_shape
+
+        def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze all_gather memory footprint
+            all_gather will allocate memory for the output tensor, and there will be temp memory for
+            all_gather operation, which is twice the size of output tensor
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            input_shape = compute_shape(comm_spec.sharding_spec)
+            input_numel = np.prod(input_shape)
+            output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+            peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
+            alloc_numel += output_numel
+            if discard_input:
+                alloc_numel -= input_numel
+
+            return alloc_numel, peak_numel
+
+        def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze split memory footprint
+            split will allocate memory for the output tensor if we don't apply shard on the first dimension of
+            the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
+            generate new tensor in this case, so no memory will be allocated.
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            shard_dim = comm_spec.shard_dim
+            if shard_dim != 0:
+                # if we don't shard the tensor on the first dimension, the split action will
+                # generate a new tensor
+                input_shape = compute_shape(comm_spec.sharding_spec)
+                input_numel = np.prod(input_shape)
+                output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+                alloc_numel += output_numel
+                peak_numel = max(peak_numel, alloc_numel)
+                if discard_input:
+                    alloc_numel -= input_numel
+            else:
+                # if we shard the tensor on the first dimension, the split action will not generate
+                # a new tensor, and as it will preserve a reference to the input tensor, we could
+                # override the discard_input option here
+                # NOTE: this special case might fail in some weird cases, e.g. if we have three split
+                # actions in the comm actions sequence, the first split action operate on the second dimension,
+                # the second split action operate on the first dimension, and the third split action operate, again,
+                # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
+                # memory the same size as the output of first split action. However, the third split action will discard
+                # the input tensor, and it actually should discard the tensor generated by the first split action, so in
+                # the current memory estimation framework, we will overestimate the memory usage. But the above case is
+                # kind of weird, and I think we could ignore it for now.
+                pass
+
+            return alloc_numel, peak_numel
+
+        def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """
+            a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
+            """
+            return alloc_numel, peak_numel
+
+        def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze all_to_all memory footprint
+            all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action
+            is twice the size of output tensor if we shard input tensor on the first dimension, otherwise
+            the temp memory is three times the size of output tensor
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            input_shape = compute_shape(comm_spec.sharding_spec)
+            input_numel = np.prod(input_shape)
+            output_numel = input_numel
+            shard_dim = comm_spec.shard_dim
+            if shard_dim != 0:
+                peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
+            else:
+                peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
+            alloc_numel += output_numel
+            if discard_input:
+                alloc_numel -= input_numel
+
+            return alloc_numel, peak_numel
+
+        def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """
+            a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
+            """
+            return alloc_numel, peak_numel
+
+        pattern_to_func_dict = {
+            CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
+            CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
+            CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
+            CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
+            CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
+            CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
+        }
+
+        fwd_actions = []
+        bwd_actions = []
+
+        # construct forward and backward comm actions sequence
+        for comm_spec in comm_action_sequence:
+            comm_spec: CommSpec
+            fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
+            fwd_actions.append(fwd_action)
+            bwd_actions.append(bwd_action)
+
+        # analyze memory footprint of forward comm actions sequence
+        fwd_alloc_numel = 0
+        fwd_peak_numel = 0
+        for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
+            # the first forward comm action will not discard input
+            fwd_action, comm_spec = action_spec_pair
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+
+        # analyze memory footprint for backward comm actions sequence
+        bwd_alloc_numel = 0
+        bwd_peak_numel = 0
+        for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+            bwd_action, comm_spec = action_spec_pair
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+
+        fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
+        bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
+        total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)
+
+        return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
+
+    def shape_consistency(self, source_spec: ShardingSpec,
+                          target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
        '''
        This method will find a path to transform source_spec to target_spec with
        a greedy algorithm.
...
@@ -450,7 +671,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")

-    def apply(self, tensor_with_sharding_spec, target_spec):
+    def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
        '''
        Apply target_spec to tensor with source sharding spec, the transform path is generated by the
        shape_consistency method.
...
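The alloc/peak bookkeeping in `mem_cost` can be followed with plain numbers. The self-contained sketch below replays the forward analysis of a gather-then-split sequence for a hypothetical 1024-element shard on a 4-way mesh axis; explicit numbers stand in for the `CommSpec` objects, and the helper names are local to this example:

    def gather_analysis(input_numel, world_size, discard_input, alloc_numel, peak_numel):
        # all_gather materialises the gathered output plus a temp copy of the same size
        output_numel = input_numel * world_size
        peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
        alloc_numel += output_numel
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    def split_analysis(input_numel, world_size, discard_input, alloc_numel, peak_numel):
        # narrowing along a non-leading dim copies the shard out of the gathered tensor
        output_numel = input_numel // world_size
        alloc_numel += output_numel
        peak_numel = max(peak_numel, alloc_numel)
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    alloc, peak = gather_analysis(1024, 4, False, 0, 0)       # alloc=4096, peak=8192
    alloc, peak = split_analysis(4096, 4, True, alloc, peak)  # alloc=1024, peak=8192
    print(alloc, peak)    # 1024 8192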