OpenDAS / ColossalAI · Commits · 36824a30

Commit 36824a30 (Unverified)
[Doc] add more doc for ColoTensor. (#1458)
Authored Aug 16, 2022 by Jiarui Fang; committed by GitHub on Aug 16, 2022.
Parent: a1476ea8
Showing 4 changed files with 46 additions and 18 deletions (+46, -18).
colossalai/fx/passes/shard_1d_pass.py            +4  -4
colossalai/tensor/__init__.py                    +4  -4
colossalai/tensor/distspec.py                    +37 -9
tests/test_utils/test_norm_gradient_clipping.py  +1  -1
colossalai/fx/passes/shard_1d_pass.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import operator
 from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import shard
+from colossalai.tensor.distspec import ShardSpec
 from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
     for shard_type, module in annotation_record.items():
         # add row sharding spec
         if shard_type == 'row':
-            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+            dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
             comp_spec = ComputeSpec(ComputePattern.TP1D)
             setattr(module.weight, 'pg', process_group)
             setattr(module.weight, 'dist_spec', dist_spec)
             setattr(module.weight, 'comp_spec', comp_spec)
         elif shard_type == 'col':
-            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+            weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
             weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
             weight_comp_spec.output_replicate = False
             setattr(module.weight, 'pg', process_group)
@@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
             setattr(module.weight, 'comp_spec', weight_comp_spec)
             if module.bias is not None:
-                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                 bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                 bias_comp_spec.output_replicate = False
                 setattr(module.bias, 'pg', process_group)
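For orientation, here is a hedged sketch of how this pass might be driven. The toy model, the tp_degree value, and the assumption that a distributed context is already launched are illustrative and not taken from this commit; only the transformer_mlp_pass(graph_module, process_group) call shape comes from the hunk header above.

import torch
import torch.fx
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.tensor import ProcessGroup

# Hypothetical MLP block; assumes torch.distributed is already initialized
# (e.g. via a colossalai launch) so that ProcessGroup can be constructed.
mlp = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16))
graph_module = torch.fx.symbolic_trace(mlp)

pg = ProcessGroup(tp_degree=2)          # illustrative 1D tensor-parallel group of size 2
transformer_mlp_pass(graph_module, pg)  # annotates weights/biases with 'pg', a ShardSpec and a ComputeSpec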
colossalai/tensor/__init__.py

 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
-from .distspec import shard as ShardSpec
-from .distspec import replicate as ReplicaSpec
+from .distspec import ShardSpec
+from .distspec import ReplicaSpec
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
@@ -13,6 +13,6 @@ from . import distspec
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
     'ReplicaSpec'
 ]
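For downstream code the package-level names are unchanged; only the aliasing of the lower-case factory functions goes away. A minimal before/after sketch:

# Unchanged for package-level users, before and after this commit:
from colossalai.tensor import ShardSpec, ReplicaSpec

# What changes is the module-level import from distspec:
#   before: from colossalai.tensor.distspec import shard, replicate
#   after:  from colossalai.tensor.distspec import ShardSpec, ReplicaSpec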
colossalai/tensor/distspec.py

 from enum import Enum
 from typing import List

-__all__ = ['replicate', 'shard']
+__all__ = ['ReplicaSpec', 'ShardSpec']


 class DistPlacementPattern(Enum):
@@ -10,15 +10,22 @@ class DistPlacementPattern(Enum):
 class _DistSpec:
+    """_DistSpec
+
+    A class that indicates a distributed specification.
+    The dist spec only applies to tensor parallel process groups, because the dist spec of a
+    data parallel process group can be deduced automatically. This is an internal data structure;
+    the user-facing API is `ShardSpec` and `ReplicaSpec`.
+
+    Args:
+        dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+            The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+        process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
+    """

     def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
-        """_DistSpec, Distributed Specification
-
-        Args:
-            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
-                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
-            process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
-        """
         self.placement = dist_placement_pattern
         for k, v in meta_info.items():
             setattr(self, k, v)
@@ -39,11 +46,32 @@ class _DistSpec:
         return ''.join(res_list)


-def replicate() -> _DistSpec:
+def ReplicaSpec() -> _DistSpec:
+    """ReplicaSpec
+
+    A distributed specification representing a tensor replicated across the tensor parallel process group.
+
+    Returns:
+        _DistSpec: a replicated dist spec instance.
+    """
     return _DistSpec(DistPlacementPattern.REPLICATE)


-def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
+def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
+    """ShardSpec
+
+    A distributed specification representing a tensor sharded across the tensor parallel process group.
+
+    Note:
+        Currently only sharding along a single dimension is valid, i.e. `dims` should be of size 1.
+
+    Args:
+        dims (List[int]): a list of dimensions to shard.
+        num_partitions (List[int]): the number of partitions for each dimension.
+
+    Returns:
+        _DistSpec: a sharded dist spec instance.
+    """
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
     return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
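A minimal usage sketch of the renamed factories (the world size here is hypothetical; both calls return an internal _DistSpec instance as documented above):

from colossalai.tensor import ReplicaSpec, ShardSpec

tp_world_size = 2  # hypothetical size of the tensor-parallel process group

# Replicate the tensor on every rank of the tensor-parallel group.
replica_spec = ReplicaSpec()

# Shard dim 0 of the tensor into tp_world_size partitions.
# Only one dimension may be sharded, so both lists have length 1.
shard_spec = ShardSpec(dims=[0], num_partitions=[tp_world_size])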
tests/test_utils/test_norm_gradient_clipping.py

@@ -19,7 +19,7 @@ def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
 def shard_param(p: ColoParameter) -> None:
     pg = p.get_process_group()
-    p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
+    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
     p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
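To make the intent of this helper explicit, here is a commented sketch of the updated shard_param as it appears after the rename; the names and calls are those used in the test, while the surrounding test setup (a ColoParameter with a populated grad) is assumed.

from colossalai.tensor import ColoParameter, distspec

def shard_param(p: ColoParameter) -> None:
    pg = p.get_process_group()
    # Redistribute the parameter: shard dim 0 across the tensor-parallel ranks.
    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
    # Shard the gradient by hand so it matches the local parameter shard.
    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()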