OpenDAS / ColossalAI · Commit 280a8124 (Unverified)
[tensor] improve robustness of class 'ProcessGroup' (#1223)
Authored Jul 07, 2022 by HELSON; committed by GitHub on Jul 07, 2022
parent 15d988f9
Showing 1 changed file with 42 additions and 35 deletions.

colossalai/tensor/process_group.py (+42, -35)
@@ -10,29 +10,17 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
          # distributed settings
          self.dict = {}

-     def get(self, rank: int, world_size: int, tp_degree: int, dp_degree: int, backend: str = 'nccl'):
-         key = (tp_degree, dp_degree, backend)
-         if key in self.dict:
-             return self.dict[key]
-         else:
-             self.logger = get_dist_logger('PyTorchProcessGroupDict')
-             _tp_rank_list = []
-             _dp_rank_list = []
-             for rank_id in range(world_size):
-                 # rank_id and self._rank in the same tp group
-                 if rank_id % tp_degree == rank % tp_degree:
-                     _dp_rank_list.append(rank_id)
-                 if rank_id // tp_degree == rank // tp_degree:
-                     _tp_rank_list.append(rank_id)
-             _tp_process_group = torch.distributed.new_group(ranks=_tp_rank_list, backend=backend)
-             _dp_process_group = torch.distributed.new_group(ranks=_dp_rank_list, backend=backend)
-             self.logger.info(
-                 f'rank {rank} initialize process group on {backend}, dp ranks: {_dp_rank_list} tp ranks: {_tp_rank_list}')
-             self.dict[key] = _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
-             return _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
+     def get(self, rank_list: List[int], backend: str = 'nccl'):
+         """Reuse Pytorch ProcessGroup when such a group is initialized
+         """
+         rank_tuple = tuple(rank_list)
+         # we need to convert the passed list to a tuple
+         # since List is unhashable
+         pg_key = (backend, rank_tuple)
+         if pg_key not in self.dict:
+             self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+         return self.dict[pg_key]

  PYTORCHPGDICT_ = PyTorchProcessGroupDict()
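For context, the rewritten `get` caches every created group under the key `(backend, tuple(rank_list))`, so repeated requests for the same rank list and backend reuse one process group instead of creating another. Below is a minimal sketch of that caching scheme, not ColossalAI code: `get_group` and `_cache` are hypothetical names, and a placeholder object stands in for `torch.distributed.new_group` so the snippet runs standalone.

    from typing import List

    _cache = {}

    def get_group(rank_list: List[int], backend: str = 'nccl'):
        # lists are unhashable, so the cache key combines the backend with a tuple of ranks
        pg_key = (backend, tuple(rank_list))
        if pg_key not in _cache:
            _cache[pg_key] = object()    # stand-in for torch.distributed.new_group(ranks=rank_list, backend=backend)
        return _cache[pg_key]

    a = get_group([0, 1, 2, 3], 'nccl')
    b = get_group([0, 1, 2, 3], 'nccl')
    assert a is b          # same rank list and backend: the cached group is reused
    c = get_group([0, 1, 2, 3], 'gloo')
    assert c is not a      # a different backend gets its own group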
@@ -50,7 +38,6 @@ class ProcessGroup:
          dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
      """

-     #TODO(haichen) fix me! ranks now must start from 0,1,2,3...
      def __init__(self,
                   rank: Optional[int] = None,
                   ranks: Optional[List[int]] = None,
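A hedged usage sketch of the constructor whose signature appears above, assuming `torch.distributed` is already initialized (for example by `colossalai.launch`) and that the class is importable as `colossalai.tensor.ProcessGroup`; with 8 processes and `tp_degree=2`, the branch logic in the next hunk infers `dp_degree = 8 // 2 = 4`.

    import torch
    from colossalai.tensor import ProcessGroup

    # guarded so the snippet is harmless outside a launched distributed job
    if torch.distributed.is_initialized():
        # rank and ranks default to the current global rank and all ranks;
        # dp_degree is inferred as world_size // tp_degree
        pg = ProcessGroup(tp_degree=2)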
@@ -69,37 +56,57 @@ class ProcessGroup:
              self._rank_list = list(range(torch.distributed.get_world_size()))
          else:
              self._rank_list = ranks
+         self._rank_list.sort()    # ensure that the list is in order
+         self._rank_idx = self._rank_list.index(self._rank)
          self._world_size = len(self._rank_list)

          if dp_degree is None and tp_degree is None:
              self._dp_degree = self._world_size
              self._tp_degree = 1
-         if dp_degree and not tp_degree:
+         elif dp_degree and not tp_degree:
              self._dp_degree = dp_degree
              assert self._world_size % self._dp_degree == 0, \
                  f"DP degree {dp_degree} should be divisible by {self._world_size} when DP degree is None"
              self._tp_degree = self._world_size // dp_degree
-         if not dp_degree and tp_degree:
+         elif not dp_degree and tp_degree:
              self._tp_degree = tp_degree
              assert self._world_size % self._tp_degree == 0, \
                  f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
              self._dp_degree = self._world_size // tp_degree
-         self._tp_rank_list, self._tp_process_group, self._dp_rank_list, self._dp_process_group = \
-             PYTORCHPGDICT_.get(self._rank, self._world_size, self._tp_degree, self._dp_degree, 'nccl')
+         else:
+             self._dp_degree = dp_degree
+             self._tp_degree = tp_degree
+             assert self._dp_degree * self._tp_degree == self._world_size, \
+                 f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
+                 f"and TP degree {self._tp_degree}"
+
+         self._tp_rank_list = []
+         self._dp_rank_list = []
+         for idx, rank_id in enumerate(self._rank_list):
+             # idx and self._rank_idx in the same tp group
+             if idx % self._tp_degree == self._rank_idx % self._tp_degree:
+                 self._dp_rank_list.append(rank_id)
+             if idx // self._tp_degree == self._rank_idx // self._tp_degree:
+                 self._tp_rank_list.append(rank_id)
+
+         self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
+         self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
+
+         self.logger = get_dist_logger('ProcessGroup')
+         self.logger.info(
+             f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')

          self._has_cpu_groups = False
          self._cpu_dp_process_group = None
          self._cpu_tp_process_group = None

      def set_cpu_groups(self):
          if self.has_cpu_groups:
              return
          self.logger.info(
              f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-         self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
-         self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
-         _, self._cpu_tp_process_group, _, self._cpu_dp_process_group = \
-             PYTORCHPGDICT_.get(self._rank, self._world_size, self._tp_degree, self._dp_degree, 'gloo')
+         self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+         self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

      @property
      def has_cpu_groups(self):
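To make the rank-partition rule added to `__init__` concrete: ranks are grouped by their position in the sorted rank list, with positions sharing the same `idx // tp_degree` forming a TP group and positions sharing the same `idx % tp_degree` forming a DP group. The sketch below is not ColossalAI code; `split_tp_dp` is a hypothetical helper that reproduces the same loop for a single rank.

    from typing import List, Tuple

    def split_tp_dp(rank: int, rank_list: List[int], tp_degree: int) -> Tuple[List[int], List[int]]:
        rank_list = sorted(rank_list)                 # mirrors self._rank_list.sort()
        rank_idx = rank_list.index(rank)              # mirrors self._rank_idx
        tp_rank_list, dp_rank_list = [], []
        for idx, rank_id in enumerate(rank_list):
            if idx % tp_degree == rank_idx % tp_degree:
                dp_rank_list.append(rank_id)          # same position within its TP group
            if idx // tp_degree == rank_idx // tp_degree:
                tp_rank_list.append(rank_id)          # same TP group
        return tp_rank_list, dp_rank_list

    # world size 8 with tp_degree=2 (so dp_degree=4), viewed from rank 5
    tp, dp = split_tp_dp(5, list(range(8)), tp_degree=2)
    print(tp)   # [4, 5]
    print(dp)   # [1, 3, 5, 7]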