Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
15d988f9
Unverified
Commit
15d988f9
authored
Jul 07, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 07, 2022
Browse files
[tensor] sharded global process group (#1219)
parent
db1bef90
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
18 deletions
+42
-18
colossalai/tensor/process_group.py
colossalai/tensor/process_group.py
+41
-17
tests/test_tensor/test_tensor.py
tests/test_tensor/test_tensor.py
+1
-1
No files found.
colossalai/tensor/process_group.py
View file @
15d988f9
import
torch
import
torch
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.context.singleton_meta
import
SingletonMeta
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
    """Cache of tensor-parallel (TP) and data-parallel (DP) process groups.

    Groups are stored per ``(tp_degree, dp_degree, backend)`` so that repeated
    requests with the same layout reuse the ``torch.distributed`` groups
    instead of creating new ones.
    """

    def __init__(self):
        # distributed settings
        # (tp_degree, dp_degree, backend) ->
        #   (tp_rank_list, tp_process_group, dp_rank_list, dp_process_group)
        self.dict = {}

    def get(self, rank: int, world_size: int, tp_degree: int, dp_degree: int, backend: str = 'nccl'):
        """Return the TP/DP rank lists and process groups that contain ``rank``.

        Args:
            rank: global rank of the calling process.
            world_size: total number of ranks to partition.
            tp_degree: number of consecutive ranks per TP group.
            dp_degree: data-parallel degree; only used as part of the cache key.
            backend: backend name forwarded to ``torch.distributed.new_group``.

        Returns:
            A tuple ``(tp_rank_list, tp_process_group, dp_rank_list,
            dp_process_group)``.
        """
        # NOTE(review): the key does not include ``rank`` or ``world_size``;
        # this assumes both are fixed for the lifetime of the process -- confirm.
        key = (tp_degree, dp_degree, backend)
        if key in self.dict:
            return self.dict[key]

        self.logger = get_dist_logger('PyTorchProcessGroupDict')

        _tp_rank_list, _tp_process_group = [], None
        _dp_rank_list, _dp_process_group = [], None

        # torch.distributed.new_group has to be entered by every process, for
        # every group, in the same order -- including groups the caller is not
        # a member of. So all groups are created here and only the ones that
        # contain ``rank`` are kept.

        # TP groups: ranks sharing the same ``rank_id // tp_degree``.
        for first in range(0, world_size, tp_degree):
            members = list(range(first, min(first + tp_degree, world_size)))
            group = torch.distributed.new_group(ranks=members, backend=backend)
            if rank in members:
                _tp_rank_list, _tp_process_group = members, group

        # DP groups: ranks sharing the same ``rank_id % tp_degree``.
        for offset in range(min(tp_degree, world_size)):
            members = list(range(offset, world_size, tp_degree))
            group = torch.distributed.new_group(ranks=members, backend=backend)
            if rank in members:
                _dp_rank_list, _dp_process_group = members, group

        self.logger.info(
            f'rank {rank} initialize process group on {backend}, dp ranks: {_dp_rank_list} tp ranks: {_tp_rank_list}')

        result = (_tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group)
        self.dict[key] = result
        return result


# Module-level instance used by ProcessGroup to obtain its groups.
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
class
ProcessGroup
:
class
ProcessGroup
:
...
@@ -15,6 +50,7 @@ class ProcessGroup:
...
@@ -15,6 +50,7 @@ class ProcessGroup:
dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
"""
"""
#TODO(haichen) fix me! ranks now must start from 0,1,2,3...
def
__init__
(
self
,
def
__init__
(
self
,
rank
:
Optional
[
int
]
=
None
,
rank
:
Optional
[
int
]
=
None
,
ranks
:
Optional
[
List
[
int
]]
=
None
,
ranks
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -50,23 +86,8 @@ class ProcessGroup:
...
@@ -50,23 +86,8 @@ class ProcessGroup:
assert
self
.
_world_size
%
self
.
_tp_degree
==
0
,
f
"TP degree
{
tp_degree
}
should be divisible by
{
self
.
_world_size
}
when DP degree is None"
assert
self
.
_world_size
%
self
.
_tp_degree
==
0
,
f
"TP degree
{
tp_degree
}
should be divisible by
{
self
.
_world_size
}
when DP degree is None"
self
.
_dp_degree
=
self
.
_world_size
//
tp_degree
self
.
_dp_degree
=
self
.
_world_size
//
tp_degree
self
.
_tp_rank_list
=
[]
self
.
_tp_rank_list
,
self
.
_tp_process_group
,
self
.
_dp_rank_list
,
self
.
_dp_process_group
=
PYTORCHPGDICT_
.
get
(
self
.
_dp_rank_list
=
[]
self
.
_rank
,
self
.
_world_size
,
self
.
_tp_degree
,
self
.
_dp_degree
,
'nccl'
)
for
rank_id
in
range
(
self
.
_world_size
):
# rank_id and self._rank in the same tp group
if
rank_id
%
self
.
_tp_degree
==
self
.
_rank
%
self
.
_tp_degree
:
self
.
_dp_rank_list
.
append
(
rank_id
)
if
rank_id
//
self
.
_tp_degree
==
self
.
_rank
//
self
.
_tp_degree
:
self
.
_tp_rank_list
.
append
(
rank_id
)
self
.
_tp_process_group
=
torch
.
distributed
.
new_group
(
ranks
=
self
.
_tp_rank_list
,
backend
=
'nccl'
)
self
.
_dp_process_group
=
torch
.
distributed
.
new_group
(
ranks
=
self
.
_dp_rank_list
,
backend
=
'nccl'
)
self
.
logger
=
get_dist_logger
(
'ProcessGroup'
)
self
.
logger
.
info
(
f
'
{
self
.
_rank
}
NCCL initialize TP group on
{
self
.
_tp_rank_list
}
, DP group on
{
self
.
_dp_rank_list
}
'
)
self
.
_has_cpu_groups
=
False
self
.
_has_cpu_groups
=
False
def
set_cpu_groups
(
self
):
def
set_cpu_groups
(
self
):
...
@@ -77,6 +98,9 @@ class ProcessGroup:
...
@@ -77,6 +98,9 @@ class ProcessGroup:
self
.
_cpu_tp_process_group
=
torch
.
distributed
.
new_group
(
ranks
=
self
.
_tp_rank_list
,
backend
=
'gloo'
)
self
.
_cpu_tp_process_group
=
torch
.
distributed
.
new_group
(
ranks
=
self
.
_tp_rank_list
,
backend
=
'gloo'
)
self
.
_cpu_dp_process_group
=
torch
.
distributed
.
new_group
(
ranks
=
self
.
_dp_rank_list
,
backend
=
'gloo'
)
self
.
_cpu_dp_process_group
=
torch
.
distributed
.
new_group
(
ranks
=
self
.
_dp_rank_list
,
backend
=
'gloo'
)
_
,
self
.
_cpu_tp_process_group
,
_
,
self
.
_cpu_dp_process_group
=
PYTORCHPGDICT_
.
get
(
self
.
_rank
,
self
.
_world_size
,
self
.
_tp_degree
,
self
.
_dp_degree
,
'gloo'
)
@
property
@
property
def
has_cpu_groups
(
self
):
def
has_cpu_groups
(
self
):
return
self
.
_has_cpu_groups
return
self
.
_has_cpu_groups
...
...
tests/test_tensor/test_tensor.py
View file @
15d988f9
...
@@ -103,9 +103,9 @@ def run_dist_tests(rank, world_size, port):
...
@@ -103,9 +103,9 @@ def run_dist_tests(rank, world_size, port):
_run_view
(
world_size
)
_run_view
(
world_size
)
_run_process_group
(
world_size
)
_run_process_group
(
world_size
)
_run_tensor_indexing
()
_run_tensor_indexing
()
_run_operand
()
# TODO not passed
# TODO not passed
# _run_wrapped_tensor_func()
# _run_wrapped_tensor_func()
_run_operand
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment