OpenDAS / ColossalAI / Commits / 552183bb

Commit 552183bb (unverified), authored Feb 03, 2023 by HELSON, committed by GitHub on Feb 03, 2023
Parent: 51d4d6e7

[polish] polish ColoTensor and its submodules (#2537)

Showing 6 changed files with 75 additions and 65 deletions (+75, -65)
Changed files:
colossalai/tensor/colo_parameter.py          +1  -1
colossalai/tensor/colo_tensor.py             +6  -1
colossalai/tensor/compute_spec.py            +3  -3
colossalai/tensor/distspec.py                +5  -4
colossalai/tensor/process_group.py           +37 -29
colossalai/utils/model/colo_init_context.py  +23 -27

colossalai/tensor/colo_parameter.py
@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor

     def __repr__(self):
-        return f'ColoParameter: {ColoTensor.__repr__(self)}'
+        return super(ColoParameter, self).__repr__()

     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
colossalai/tensor/colo_tensor.py
@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
         return _convert_output(ret, colo_spec)

     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+        output_list = [super(ColoTensor, self).__repr__()]
+        output_list.append(str(self.process_group))
+        output_list.append(str(self.dist_spec))
+        if self.compute_spec is not None:
+            output_list.append(str(self.compute_spec))
+        return "\n".join(output_list)

     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute
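The new `__repr__` assembles the representation from a list of parts and only includes the compute spec when one is set. A minimal standalone sketch of the same assemble-and-join pattern (illustrative only; `SpecView` is a made-up toy class, not code from this commit):

class SpecView:
    """Toy object mimicking how a repr can be built from optional parts."""

    def __init__(self, data_repr, process_group, dist_spec, compute_spec=None):
        self.data_repr = data_repr
        self.process_group = process_group
        self.dist_spec = dist_spec
        self.compute_spec = compute_spec

    def __repr__(self):
        output_list = [self.data_repr]
        output_list.append(str(self.process_group))
        output_list.append(str(self.dist_spec))
        # the optional field is appended only when it is present
        if self.compute_spec is not None:
            output_list.append(str(self.compute_spec))
        return "\n".join(output_list)

print(SpecView("tensor([1., 2.])", "ProcessGroup(...)", "DistSpec(...)"))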
colossalai/tensor/compute_spec.py
@@ -9,9 +9,9 @@ class ComputePattern(Enum):
 class ComputeSpec(object):
-    """ComputeSpec
+    """ComputeSpec
     The Specification for compuattion pattern
     Args:
         compute_pattern (ComputePattern): an Enum instance for compute pattern.
     """

@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True

     def __repr__(self):
-        return f'Compute pattern: {self.compute_pattern}'
+        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'

     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag
colossalai/tensor/distspec.py
@@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):
 class _DistSpec:
     """_DistSpec
     A class indicates Distributed Specification.
     The DistSpec is only works for the tensor parallel process groups.
     Because the dist spec of data parallel process group can be automatically deduced.

@@ -39,11 +39,12 @@ class _DistSpec:
         return True

     def __repr__(self) -> str:
-        res_list = ["DistSpec:"]
+        attr_list = []
         for attr in dir(self):
             if not attr.startswith('__'):
-                res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
-        return ''.join(res_list)
+                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+        attr_str = ", ".join(attr_list)
+        return "DistSpec(" + attr_str + ")"


 def ReplicaSpec() -> _DistSpec:
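The reworked `__repr__` walks `dir(self)`, skips dunder names, and renders the remaining fields as `name=value` pairs inside `DistSpec(...)`. A small self-contained sketch of that introspection pattern (illustrative; `Point` is not a ColossalAI class):

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self) -> str:
        attr_list = []
        # dir() also returns methods; this sketch additionally filters
        # callables, which the real implementation may not need to do.
        for attr in dir(self):
            if not attr.startswith('__'):
                value = getattr(self, attr)
                if not callable(value):
                    attr_list.append(f'{attr}={value}')
        return "Point(" + ", ".join(attr_list) + ")"

print(Point(1, 2))    # Point(x=1, y=2)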
colossalai/tensor/process_group.py
-import torch
 from typing import List, Optional
-from colossalai.logging import get_dist_logger
+import torch

 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger


 class PyTorchProcessGroupDict(metaclass=SingletonMeta):

     def __init__(self):
         # distributed settings
+        # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
+        # set a distributed logger
+        self.logger = get_dist_logger('ProcessGroup')

+    def log_pg_init(self, rank_list: List[int], backend: str):
+        str_list = ["Pytorch ProcessGroup Init:"]
+        str_list.append(f"backend: {backend}")
+        str_list.append(f"ranks: {rank_list}")
+        self.logger.info("\n\t".join(str_list), ranks=[0])

     def get(self, rank_list: List[int], backend: str = 'nccl'):
         """Reuse Pytorch ProcessGroup when such a group is initialized
         """
-        rank_tuple = tuple(rank_list)
         # we need to convert the passed list to a tuple
         # since List is unhashable
-        pg_key = (backend, rank_tuple)
-        if pg_key not in self.dict:
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
-            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
-        return self.dict[pg_key]
+        processgroup_key = (backend, tuple(rank_list))
+        if processgroup_key not in self.dict:
+            self.log_pg_init(rank_list=rank_list, backend=backend)
+            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[processgroup_key]


 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
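`PyTorchProcessGroupDict` is a process-wide singleton that hands out one `torch.distributed` group per `(backend, ranks)` combination instead of creating duplicates. A minimal sketch of the same memoization pattern, with the collective backend stubbed out so it runs without `torch.distributed` being initialized (`GroupCache` and its placeholder value are made up for illustration, not ColossalAI APIs):

from typing import Dict, List, Tuple


class GroupCache:
    """Reuse an expensive per-(backend, ranks) resource instead of rebuilding it."""

    def __init__(self):
        self._groups: Dict[Tuple[str, Tuple[int, ...]], object] = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        # lists are unhashable, so the ranks are frozen into a tuple for the key
        key = (backend, tuple(rank_list))
        if key not in self._groups:
            # in the real class this is where torch.distributed.new_group(...) is called
            self._groups[key] = f"group(backend={backend}, ranks={rank_list})"
        return self._groups[key]


cache = GroupCache()
assert cache.get([0, 1]) is cache.get([0, 1])    # the second call reuses the first group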
@@ -40,7 +47,7 @@ class ProcessGroup:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
         backend: str, the backend of the process group.
-        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
+        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
         dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
@@ -54,10 +61,10 @@ class ProcessGroup:
             return

         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"

-        if rank is None:
-            self._rank = torch.distributed.get_rank()
-        else:
-            self._rank = rank
+        self._rank = torch.distributed.get_rank()
+        if rank is not None:
+            assert self._rank == rank  # make sure that the global rank is correct

         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
@@ -104,7 +111,7 @@ class ProcessGroup:
         self.is_init = True

     def set_cpu_groups(self):
-        """set_cpu_groups
+        """set_cpu_groups
         Initialize Pytorch process groups for cpu communications.
         """
         if self.has_cpu_groups:
@@ -122,7 +129,7 @@ class ProcessGroup:
     @property
     def has_cpu_groups(self) -> bool:
-        """has_cpu_groups
+        """has_cpu_groups
         If cpu groups have been initailized.

         Returns:
@@ -132,8 +139,9 @@ class ProcessGroup:
     def __repr__(self):
         if self.is_init:
-            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
-                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+            personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+            return ranks_str + personal_str
         else:
             return "ProcessGroup not initialized"
@@ -155,7 +163,7 @@ class ProcessGroup:
         return True

     def rank(self) -> int:
-        """rank
+        """rank
         The current rank in the global process group.
@@ -165,9 +173,9 @@ class ProcessGroup:
         return self._rank

     def ranks_in_group(self) -> List[int]:
-        """ranks_in_group
+        """ranks_in_group
-        a list of rank number in in the global process group.
+        a list of rank number in in the global process group.

         Returns:
             List[int]: a list of rank number.
@@ -177,7 +185,7 @@ class ProcessGroup:
     def world_size(self) -> int:
         """world_size
-        The world size of the global process group.
+        The world size of the global process group.

         Returns:
             int: world size
@@ -185,7 +193,7 @@ class ProcessGroup:
         return self._world_size

     def tp_rank_list(self) -> List[int]:
-        """tp_rank_list
+        """tp_rank_list
         the rank list in the TP process group containing the current rank.
@@ -195,7 +203,7 @@ class ProcessGroup:
         return self._tp_rank_list

     def dp_rank_list(self) -> List[int]:
-        """dp_rank_list
+        """dp_rank_list
         the rank list in the DP process group containing the current rank.
@@ -205,7 +213,7 @@ class ProcessGroup:
         return self._dp_rank_list

     def tp_local_rank(self) -> int:
-        """tp_local_rank
+        """tp_local_rank
         The local rank number in the current TP process group.
@@ -268,7 +276,7 @@ class ProcessGroup:
         """cpu_dp_process_group
         the pytorch CPU DP process group containing the current rank.

         assert failed if cpu process group is not initialized.

         Returns:
@@ -281,7 +289,7 @@ class ProcessGroup:
         """cpu_tp_process_group
         the pytorch CPU TP process group containing the current rank.

         assert failed if cpu process group is not initialized.

         Returns:
colossalai/utils/model/colo_init_context.py
@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
     # param is the global tensor.
     if param.device.type == "meta":
         colo_param = ColoParameter(param, requires_grad=requires_grad)
-    else:
+    else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)

     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
                 colo_param.shared_param_modules.append(submodule)

-        meta_param_flag = 0
-        meta_buffer_flag = 0
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0

         for param in module.parameters():
-            if param.device.type == "meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type != "meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')

         for buffer in module.buffers():
-            if buffer.device.type == "meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type != "meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
-
-        if meta_param_flag == 1 and meta_buffer_flag == 1:
-            pass
-        elif meta_buffer_flag == 0 and meta_param_flag == 1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag == 0 and meta_buffer_flag == 1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')
+
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)


 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
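The flag-based checks are replaced by simple counts: a model must have either all-meta or all-materialized parameters (and likewise for buffers), and buffers are only moved to the target device when none of them are meta. A small self-contained sketch of the parameter-counting check (the helper name `check_meta_consistency` is hypothetical, not part of the commit; constructing a module with device='meta' assumes a reasonably recent PyTorch):

import torch
import torch.nn as nn


def check_meta_consistency(module: nn.Module) -> None:
    param_number = 0
    meta_param_number = 0
    for param in module.parameters():
        param_number += 1
        meta_param_number += (param.device.type == 'meta')
    # mixing meta and materialized parameters is rejected
    if meta_param_number > 0 and meta_param_number != param_number:
        raise ValueError("Meta parameters and valued parameters can not be in the same model")


check_meta_consistency(nn.Linear(4, 4))                  # all materialized: ok
check_meta_consistency(nn.Linear(4, 4, device='meta'))   # all meta: ok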