OpenDAS / ColossalAI · Commits

Commit 552183bb (unverified), authored Feb 03, 2023 by HELSON; committed by GitHub on Feb 03, 2023. Parent: 51d4d6e7.

[polish] polish ColoTensor and its submodules (#2537)

Showing 6 changed files with 75 additions and 65 deletions (+75, -65).
Changed files:
- colossalai/tensor/colo_parameter.py          (+1, -1)
- colossalai/tensor/colo_tensor.py             (+6, -1)
- colossalai/tensor/compute_spec.py            (+3, -3)
- colossalai/tensor/distspec.py                (+5, -4)
- colossalai/tensor/process_group.py           (+37, -29)
- colossalai/utils/model/colo_init_context.py  (+23, -27)
colossalai/tensor/colo_parameter.py

```diff
@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor

     def __repr__(self):
-        return f'ColoParameter: {ColoTensor.__repr__(self)}'
+        return super(ColoParameter, self).__repr__()

     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
```
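This change replaces the hard-coded `ColoTensor.__repr__(self)` call with `super(ColoParameter, self).__repr__()`, so the string now follows the MRO and automatically tracks whatever formatting `ColoTensor.__repr__` (rewritten in the next file) produces. A minimal, self-contained sketch of the same delegation pattern; the class names here are illustrative, not ColossalAI's:

```python
class Base:

    def __repr__(self):
        # single source of truth for formatting in the hierarchy
        return f"{type(self).__name__}(fields=...)"


class Child(Base):

    def __repr__(self):
        # delegate along the MRO instead of naming Base explicitly;
        # refactoring Base.__repr__ then needs no change here
        return super(Child, self).__repr__()


print(Child())  # -> Child(fields=...)
```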
colossalai/tensor/colo_tensor.py

```diff
@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
         return _convert_output(ret, colo_spec)

     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+        output_list = [super(ColoTensor, self).__repr__()]
+        output_list.append(str(self.process_group))
+        output_list.append(str(self.dist_spec))
+        if self.compute_spec is not None:
+            output_list.append(str(self.compute_spec))
+        return "\n".join(output_list)

     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute
```
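Building the representation as a list that is joined with newlines lets the optional `compute_spec` line be dropped when it is `None`; the old single f-string would have printed the literal `None`. A runnable sketch of the pattern with stand-in attributes (names illustrative):

```python
from typing import Optional


class Described:

    def __init__(self, process_group: str, dist_spec: str,
                 compute_spec: Optional[str] = None):
        self.process_group = process_group
        self.dist_spec = dist_spec
        self.compute_spec = compute_spec

    def __repr__(self):
        # mandatory parts first, optional parts appended conditionally
        output_list = ["Described"]
        output_list.append(str(self.process_group))
        output_list.append(str(self.dist_spec))
        if self.compute_spec is not None:
            output_list.append(str(self.compute_spec))
        return "\n".join(output_list)


print(Described("pg=[0, 1]", "DistSpec(...)"))          # 3 lines, no 'None'
print(Described("pg=[0, 1]", "DistSpec(...)", "TP1D"))  # 4 lines
```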
colossalai/tensor/compute_spec.py

```diff
@@ -9,9 +9,9 @@ class ComputePattern(Enum):

 class ComputeSpec(object):
     """ComputeSpec
     The Specification for compuattion pattern
     Args:
         compute_pattern (ComputePattern): an Enum instance for compute pattern.
     """
@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True

     def __repr__(self):
-        return f'Compute pattern: {self.compute_pattern}'
+        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'

     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag
```
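With this change a spec prints both of its fields in one parenthesized line. Assuming a `ComputePattern.TP1D` member as in ColossalAI, the before/after output would look roughly like:

```python
spec = ComputeSpec(ComputePattern.TP1D)
print(spec)
# before: Compute pattern: ComputePattern.TP1D
# after:  ComputeSpec(pattern=ComputePattern.TP1D, replicate_output=True)
```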
colossalai/tensor/distspec.py

```diff
@@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):

 class _DistSpec:
     """_DistSpec
     A class indicates Distributed Specification.
     The DistSpec is only works for the tensor parallel process groups.
     Because the dist spec of data parallel process group can be automatically deduced.
@@ -39,11 +39,12 @@ class _DistSpec:
         return True

     def __repr__(self) -> str:
-        res_list = ["DistSpec:"]
+        attr_list = []
         for attr in dir(self):
             if not attr.startswith('__'):
-                res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
-        return ''.join(res_list)
+                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+        attr_str = ", ".join(attr_list)
+        return "DistSpec(" + attr_str + ")"


 def ReplicaSpec() -> _DistSpec:
```
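`_DistSpec.__repr__` still discovers its fields through `dir()`, but now renders them as a compact one-line `DistSpec(a=..., b=...)` instead of one tab-indented line per attribute. A self-contained sketch of the introspection pattern; the attribute names are illustrative:

```python
class Spec:

    def __init__(self):
        self.placement = 'SHARD'
        self.dims = (0,)
        self.num_partitions = (4,)

    def __repr__(self) -> str:
        attr_list = []
        for attr in dir(self):
            # dir() returns sorted names; skip dunder members
            if not attr.startswith('__'):
                attr_list.append(f'{attr}={getattr(self, attr)}')
        return "Spec(" + ", ".join(attr_list) + ")"


print(Spec())  # -> Spec(dims=(0,), num_partitions=(4,), placement=SHARD)
```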
colossalai/tensor/process_group.py

```diff
-import torch
 from typing import List, Optional
-from colossalai.logging import get_dist_logger
+import torch
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger


 class PyTorchProcessGroupDict(metaclass=SingletonMeta):

     def __init__(self):
         # distributed settings
+        # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
+        # set a distributed logger
+        self.logger = get_dist_logger('ProcessGroup')
+
+    def log_pg_init(self, rank_list: List[int], backend: str):
+        str_list = ["Pytorch ProcessGroup Init:"]
+        str_list.append(f"backend: {backend}")
+        str_list.append(f"ranks: {rank_list}")
+        self.logger.info("\n\t".join(str_list), ranks=[0])

     def get(self, rank_list: List[int], backend: str = 'nccl'):
         """Reuse Pytorch ProcessGroup when such a group is initialized
         """
-        rank_tuple = tuple(rank_list)
         # we need to convert the passed list to a tuple
         # since List is unhashable
-        pg_key = (backend, rank_tuple)
-        if pg_key not in self.dict:
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
-            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
-        return self.dict[pg_key]
+        processgroup_key = (backend, tuple(rank_list))
+        if processgroup_key not in self.dict:
+            self.log_pg_init(rank_list=rank_list, backend=backend)
+            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[processgroup_key]


 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
```
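The refactor keeps the memoization idea intact: a process-wide singleton maps `(backend, tuple(ranks))` to an already-created group, so repeated `get()` calls never create the same PyTorch ProcessGroup twice; it also builds the logger once in `__init__` instead of re-creating it on every cache miss. A framework-free sketch of the caching logic, with a plain string standing in for `torch.distributed.new_group`:

```python
from typing import Dict, List, Tuple


class GroupCache:
    """Create a group per (backend, ranks) key on first use, then reuse it."""

    def __init__(self):
        self._groups: Dict[Tuple[str, Tuple[int, ...]], object] = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        # lists are unhashable, so the dict key uses a tuple of ranks
        key = (backend, tuple(rank_list))
        if key not in self._groups:
            # stand-in for torch.distributed.new_group(ranks=..., backend=...)
            self._groups[key] = f"group<{backend}:{rank_list}>"
        return self._groups[key]


cache = GroupCache()
assert cache.get([0, 1]) is cache.get([0, 1])  # created once, then reused
```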
```diff
@@ -40,7 +47,7 @@ class ProcessGroup:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
         backend: str, the backend of the process group.
         tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
         dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
     """
@@ -54,10 +61,10 @@ class ProcessGroup:
             return

         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"

-        if rank is None:
-            self._rank = torch.distributed.get_rank()
-        else:
-            self._rank = rank
+        self._rank = torch.distributed.get_rank()
+        if rank is not None:
+            assert self._rank == rank    # make sure that the global rank is correct

         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
@@ -104,7 +111,7 @@ class ProcessGroup:
         self.is_init = True

     def set_cpu_groups(self):
         """set_cpu_groups
         Initialize Pytorch process groups for cpu communications.
         """
         if self.has_cpu_groups:
@@ -122,7 +129,7 @@ class ProcessGroup:
     @property
     def has_cpu_groups(self) -> bool:
         """has_cpu_groups
         If cpu groups have been initailized.
         Returns:
@@ -132,8 +139,9 @@ class ProcessGroup:
     def __repr__(self):
         if self.is_init:
-            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
-                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+            personal_str = f"             rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+            return ranks_str + personal_str
         else:
             return "ProcessGroup not initialized"
@@ -155,7 +163,7 @@ class ProcessGroup:
         return True

     def rank(self) -> int:
         """rank
         The current rank in the global process group.
@@ -165,9 +173,9 @@ class ProcessGroup:
         return self._rank

     def ranks_in_group(self) -> List[int]:
         """ranks_in_group
         a list of rank number in in the global process group.
         Returns:
             List[int]: a list of rank number.
@@ -177,7 +185,7 @@ class ProcessGroup:
     def world_size(self) -> int:
         """world_size
         The world size of the global process group.
         Returns:
             int: world size
@@ -185,7 +193,7 @@ class ProcessGroup:
         return self._world_size

     def tp_rank_list(self) -> List[int]:
         """tp_rank_list
         the rank list in the TP process group containing the current rank.
@@ -195,7 +203,7 @@ class ProcessGroup:
         return self._tp_rank_list

     def dp_rank_list(self) -> List[int]:
         """dp_rank_list
         the rank list in the DP process group containing the current rank.
@@ -205,7 +213,7 @@ class ProcessGroup:
         return self._dp_rank_list

     def tp_local_rank(self) -> int:
         """tp_local_rank
         The local rank number in the current TP process group.
@@ -268,7 +276,7 @@ class ProcessGroup:
         """cpu_dp_process_group
         the pytorch CPU DP process group containing the current rank.
         assert failed if cpu process group is not initialized.
         Returns:
@@ -281,7 +289,7 @@ class ProcessGroup:
         """cpu_tp_process_group
         the pytorch CPU TP process group containing the current rank.
         assert failed if cpu process group is not initialized.
         Returns:
```
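A behavioral nuance of the constructor hunk above: the global rank is now always taken from `torch.distributed.get_rank()`, and a caller-supplied `rank` only serves as a consistency check instead of being trusted. A minimal sketch of that pattern with a stand-in for the collective backend (names hypothetical):

```python
from typing import Optional


def _runtime_rank() -> int:
    # stand-in for torch.distributed.get_rank()
    return 0


def resolve_rank(rank: Optional[int] = None) -> int:
    actual = _runtime_rank()
    if rank is not None:
        # make sure the caller's belief about the global rank is correct
        assert actual == rank, f"expected rank {rank}, running as {actual}"
    return actual


print(resolve_rank())   # trust the runtime
print(resolve_rank(0))  # validate the caller's value
```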
colossalai/utils/model/colo_init_context.py

```diff
@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
     # param is the global tensor.

     if param.device.type == "meta":
         colo_param = ColoParameter(param, requires_grad=requires_grad)
     else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)

     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
                 colo_param.shared_param_modules.append(submodule)

-        meta_param_flag = 0
-        meta_buffer_flag = 0
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0
+
         for param in module.parameters():
-            if param.device.type == "meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type != "meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')
+
         for buffer in module.buffers():
-            if buffer.device.type == "meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type != "meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
-        if meta_param_flag == 1 and meta_buffer_flag == 1:
-            pass
-        elif meta_buffer_flag == 0 and meta_param_flag == 1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag == 0 and meta_buffer_flag == 1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')
+
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)


 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
```
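The rewritten check counts meta tensors instead of carrying order-dependent flags: the old flag logic only raised when a valued tensor appeared after a meta one, so a meta parameter following a valued one slipped through, while comparing `meta_param_number` against `param_number` catches mixing in either order. A standalone demonstration with a hypothetical helper (not ColossalAI API):

```python
import torch
from torch import nn


def check_no_mixed_meta(module: nn.Module) -> None:
    """Raise if the module mixes meta and materialized parameters."""
    param_number = 0
    meta_param_number = 0
    for param in module.parameters():
        param_number += 1
        meta_param_number += (param.device.type == 'meta')
    if meta_param_number > 0 and meta_param_number != param_number:
        raise ValueError("Meta parameters and valued parameters can not be in the same model")


m = nn.Linear(4, 4)                                   # weight and bias on cpu
m.bias = nn.Parameter(torch.empty(4, device='meta'))  # meta tensor after a valued one
try:
    check_no_mixed_meta(m)
except ValueError as e:
    print(e)  # raised here; the old flag-based check missed this ordering
```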