Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
552183bb
"examples/vscode:/vscode.git/clone" did not exist on "31dc302017ff491a36088dd27ed4c76e11d5b5b7"
Unverified
Commit
552183bb
authored
Feb 03, 2023
by
HELSON
Committed by
GitHub
Feb 03, 2023
Browse files
[polish] polish ColoTensor and its submodules (#2537)
parent
51d4d6e7
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
75 additions
and
65 deletions
+75
-65
colossalai/tensor/colo_parameter.py
colossalai/tensor/colo_parameter.py
+1
-1
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+6
-1
colossalai/tensor/compute_spec.py
colossalai/tensor/compute_spec.py
+3
-3
colossalai/tensor/distspec.py
colossalai/tensor/distspec.py
+5
-4
colossalai/tensor/process_group.py
colossalai/tensor/process_group.py
+37
-29
colossalai/utils/model/colo_init_context.py
colossalai/utils/model/colo_init_context.py
+23
-27
No files found.
colossalai/tensor/colo_parameter.py
View file @
552183bb
...
...
@@ -71,7 +71,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
return
tensor
def
__repr__
(
self
):
return
f
'
ColoParameter
:
{
ColoTensor
.
__repr__
(
self
)
}
'
return
super
(
ColoParameter
,
self
)
.
__repr__
(
)
@
classmethod
def
__torch_function__
(
cls
,
func
,
types
,
args
=
...,
kwargs
=
None
):
...
...
colossalai/tensor/colo_tensor.py
View file @
552183bb
...
...
@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
return
_convert_output
(
ret
,
colo_spec
)
def
__repr__
(
self
):
return
f
'ColoTensor:
\n
{
super
().
__repr__
()
}
\n
{
self
.
dist_spec
}
\n
{
self
.
process_group
}
\n
{
self
.
compute_spec
}
'
output_list
=
[
super
(
ColoTensor
,
self
).
__repr__
()]
output_list
.
append
(
str
(
self
.
process_group
))
output_list
.
append
(
str
(
self
.
dist_spec
))
if
self
.
compute_spec
is
not
None
:
output_list
.
append
(
str
(
self
.
compute_spec
))
return
"
\n
"
.
join
(
output_list
)
def
_redistribute
(
self
,
dist_spec
:
_DistSpec
)
->
None
:
"""_redistribute
...
...
colossalai/tensor/compute_spec.py
View file @
552183bb
...
...
@@ -23,7 +23,7 @@ class ComputeSpec(object):
self
.
output_replicate
=
True
def
__repr__
(
self
):
return
f
'Compute
pattern
:
{
self
.
compute_pattern
}
'
return
f
'Compute
Spec(
pattern
=
{
self
.
compute_pattern
}
, replicate_output=
{
self
.
output_replicate
}
)
'
def
set_output_replicate
(
self
,
flag
:
bool
=
True
):
self
.
output_replicate
=
flag
colossalai/tensor/distspec.py
View file @
552183bb
...
...
@@ -39,11 +39,12 @@ class _DistSpec:
return
True
def
__repr__
(
self
)
->
str
:
res
_list
=
[
"DistSpec:"
]
attr
_list
=
[]
for
attr
in
dir
(
self
):
if
not
attr
.
startswith
(
'__'
):
res_list
.
append
(
f
'
\n\t
{
attr
}
:
{
str
(
getattr
(
self
,
attr
))
}
'
)
return
''
.
join
(
res_list
)
attr_list
.
append
(
f
'
{
attr
}
=
{
str
(
getattr
(
self
,
attr
))
}
'
)
attr_str
=
", "
.
join
(
attr_list
)
return
"DistSpec("
+
attr_str
+
")"
def
ReplicaSpec
()
->
_DistSpec
:
...
...
colossalai/tensor/process_group.py
View file @
552183bb
import
torch
from
typing
import
List
,
Optional
from
colossalai.logging
import
get_dist_logger
import
torch
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.logging
import
get_dist_logger
class
PyTorchProcessGroupDict
(
metaclass
=
SingletonMeta
):
def
__init__
(
self
):
# distributed settings
# use this dict to record all Pytorch ProcessGroups
self
.
dict
=
{}
# set a distributed logger
self
.
logger
=
get_dist_logger
(
'ProcessGroup'
)
def
log_pg_init
(
self
,
rank_list
:
List
[
int
],
backend
:
str
):
str_list
=
[
"Pytorch ProcessGroup Init:"
]
str_list
.
append
(
f
"backend:
{
backend
}
"
)
str_list
.
append
(
f
"ranks:
{
rank_list
}
"
)
self
.
logger
.
info
(
"
\n\t
"
.
join
(
str_list
),
ranks
=
[
0
])
def
get
(
self
,
rank_list
:
List
[
int
],
backend
:
str
=
'nccl'
):
"""Reuse Pytorch ProcessGroup when such a group is initialized
"""
rank_tuple
=
tuple
(
rank_list
)
# we need to convert the passed list to a tuple
# since List is unhashable
pg_key
=
(
backend
,
rank_tuple
)
if
pg_key
not
in
self
.
dict
:
self
.
logger
=
get_dist_logger
(
'ProcessGroup'
)
self
.
logger
.
info
(
f
'NCCL initialize ProcessGroup on
{
rank_list
}
'
,
ranks
=
[
0
])
self
.
dict
[
pg_key
]
=
torch
.
distributed
.
new_group
(
ranks
=
rank_list
,
backend
=
backend
)
return
self
.
dict
[
pg_key
]
processgroup_key
=
(
backend
,
tuple
(
rank_list
))
if
processgroup_key
not
in
self
.
dict
:
self
.
log_pg_init
(
rank_list
=
rank_list
,
backend
=
backend
)
self
.
dict
[
processgroup_key
]
=
torch
.
distributed
.
new_group
(
ranks
=
rank_list
,
backend
=
backend
)
return
self
.
dict
[
processgroup_key
]
PYTORCHPGDICT_
=
PyTorchProcessGroupDict
()
...
...
@@ -54,10 +61,10 @@ class ProcessGroup:
return
assert
torch
.
distributed
.
is_initialized
(),
f
"ProcessGroup must be used after distributed initialized"
if
rank
is
None
:
self
.
_rank
=
torch
.
distributed
.
get_rank
()
els
e
:
self
.
_rank
=
rank
if
rank
is
not
Non
e
:
assert
self
.
_rank
=
=
rank
# make sure that the global rank is correct
if
ranks
is
None
:
self
.
_rank_list
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
...
...
@@ -132,8 +139,9 @@ class ProcessGroup:
def
__repr__
(
self
):
if
self
.
is_init
:
return
"ProcessGroup:
\n\t
Rank: {}, World size: {}, DP degree: {}, TP degree: {}
\n\t
Ranks in group: {}"
.
\
format
(
self
.
_rank
,
self
.
_world_size
,
self
.
_dp_degree
,
self
.
_tp_degree
,
self
.
_rank_list
)
ranks_str
=
f
"ProcessGroup(ranks=
{
self
.
_rank_list
}
,
\n
"
personal_str
=
f
" rank=
{
self
.
_rank
}
, dp=
{
self
.
_dp_degree
}
, tp=
{
self
.
_tp_degree
}
)"
return
ranks_str
+
personal_str
else
:
return
"ProcessGroup not initialized"
...
...
colossalai/utils/model/colo_init_context.py
View file @
552183bb
...
...
@@ -43,7 +43,6 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
else
:
colo_param
=
ColoParameter
(
param
.
to
(
device
=
device
,
dtype
=
dtype
),
requires_grad
=
requires_grad
)
# if default_shard_plan exists, shard the param during initialization.
# This can reduce the model size after initialization.
# NOTE() embedding usually can not be correctly sharded. So I use except to handle
...
...
@@ -130,30 +129,27 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
setattr
(
submodule
,
param_name
,
colo_param
)
colo_param
.
shared_param_modules
.
append
(
submodule
)
meta_param_flag
=
0
meta_buffer_flag
=
0
param_number
=
0
meta_param_number
=
0
buffer_number
=
0
meta_buffer_number
=
0
for
param
in
module
.
parameters
():
if
param
.
device
.
type
==
"meta"
:
meta_param_flag
=
1
if
meta_param_flag
==
1
and
param
.
device
.
type
!=
"meta"
:
raise
ValueError
(
"Meta parameters and valued parameters can not be in the same model"
)
param_number
+=
1
meta_param_number
+=
(
param
.
device
.
type
==
'meta'
)
for
buffer
in
module
.
buffers
():
if
buffer
.
device
.
type
==
"meta"
:
meta_buffer_flag
=
1
if
meta_buffer_flag
==
1
and
buffer
.
device
.
type
!=
"meta"
:
buffer_number
+=
1
meta_buffer_number
+=
(
buffer
.
device
.
type
==
'meta'
)
if
meta_param_number
>
0
and
meta_param_number
!=
param_number
:
raise
ValueError
(
"Meta parameters and valued parameters can not be in the same model"
)
if
meta_buffer_number
>
0
and
meta_buffer_number
!=
buffer_number
:
raise
ValueError
(
"Meta buffers and valued buffers can not be in the same model"
)
if
meta_param_flag
==
1
and
meta_buffer_flag
==
1
:
pass
elif
meta_buffer_flag
==
0
and
meta_param_flag
==
1
:
for
name
,
buf
in
module
.
named_buffers
():
module
.
_buffers
[
name
]
=
module
.
_buffers
[
name
].
to
(
device
=
self
.
_device
)
elif
meta_param_flag
==
0
and
meta_buffer_flag
==
1
:
for
name
,
param
in
module
.
named_parameters
():
module
.
_parameters
[
name
]
=
module
.
_parameters
[
name
].
to
(
device
=
self
.
_device
)
else
:
module
.
to
(
self
.
_device
)
if
meta_buffer_number
==
0
:
for
buffer
in
module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
to
(
device
=
self
.
_device
)
def
post_process_colo_init_ctx
(
model
:
torch
.
nn
.
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment