OpenDAS / ColossalAI / Commits

Commit a77d73f2
authored Mar 10, 2022 by ziyu huang, committed by Frank Lee on Mar 11, 2022

fix format parallel_context.py (#359)

Co-authored-by: huangziyu <202476410arsmart@gmail.com>

parent c695369a

Showing 1 changed file with 23 additions and 34 deletions

colossalai/context/parallel_context.py (+23, -34), view file @ a77d73f2
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import random
from typing import Union
...
@@ -218,7 +217,8 @@ class ParallelContext:
     def is_pipeline_last_stage(self, ignore_virtual=False):
         if not ignore_virtual:
-            if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+            if self.virtual_pipeline_parallel_size \
+                    is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                 return False
         return self.is_last_rank(ParallelMode.PIPELINE)
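
Context for the hunk above (not part of the commit): with an interleaved pipeline schedule a rank hosts several virtual stages, so it only counts as the true last stage when it also holds the final virtual stage. A minimal standalone sketch of the same check, with hypothetical sizes:

    # Hypothetical values, not the ColossalAI API: 4 pipeline ranks,
    # 2 virtual (interleaved) stages per rank.
    pipeline_rank, pipeline_size = 3, 4
    virtual_rank, virtual_size = 0, 2

    def is_pipeline_last_stage(ignore_virtual=False):
        if not ignore_virtual:
            # the rank must also hold the final virtual stage
            if virtual_size is not None and virtual_rank != virtual_size - 1:
                return False
        return pipeline_rank == pipeline_size - 1

    print(is_pipeline_last_stage())                     # False: virtual stage 0 of 2
    print(is_pipeline_last_stage(ignore_virtual=True))  # True: last pipeline rank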
...
@@ -300,13 +300,7 @@ class ParallelContext:
         self._check_parallel_mode(parallel_mode)
         self._ranks_in_group[parallel_mode] = ranks

-    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
+    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
         """Initializes the global distributed environment

         :param rank: rank for the default process group
         :type rank: int
...
@@ -321,18 +315,13 @@ class ParallelContext:
         """
         # initialize the default process group
         init_method = f'tcp://{host}:{port}'
-        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
+        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
+        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)

-    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
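
Context (not part of the diff): dist.init_process_group with a tcp://host:port init method is the standard PyTorch call that init_global_dist wraps. A minimal sketch of driving the same call directly, with hypothetical host and port:

    # Hypothetical single-process example, not ColossalAI code.
    import os
    import torch.distributed as dist

    rank = int(os.environ.get('RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    init_method = 'tcp://127.0.0.1:29500'   # hypothetical host and port

    dist.init_process_group(rank=rank, world_size=world_size,
                            backend='gloo', init_method=init_method)
    print(dist.get_rank(), dist.get_world_size())
    dist.destroy_process_group()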
...
@@ -349,7 +338,9 @@ class ParallelContext:
tps
=
self
.
tensor_parallel_size
ws
=
self
.
world_size
assert
ws
==
dps
*
pps
*
\
tps
,
f
"Expected the world size
{
ws
}
to be equal to data parallel size (
{
dps
}
) * pipeline parallel size (
{
pps
}
) * tensor parallel size (
{
tps
}
)"
tps
,
f
"Expected the world size
{
ws
}
to be equal to data"
\
f
" parallel size (
{
dps
}
) * pipeline parallel size "
\
f
"(
{
pps
}
) * tensor parallel size (
{
tps
}
)"
def
_set_parallel_size_from_config
(
self
,
config
:
dict
,
key
:
str
,
attr_name
:
str
):
if
key
in
config
:
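
The invariant asserted in the hunk above, worked through with hypothetical sizes: the world size must factor exactly into the product of the three parallel sizes.

    # e.g. 32 GPUs split as 4-way data, 2-way pipeline, 4-way tensor parallel
    dps, pps, tps = 4, 2, 4
    ws = 32
    assert ws == dps * pps * tps   # 4 * 2 * 4 == 32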
...
@@ -360,8 +351,7 @@ class ParallelContext:
                 setattr(self, attr_name, ele['size'])
             else:
                 raise NotImplementedError(
-                    f"Parallel configuration does not support this kind of argument, please use int or dict"
-                )
+                    f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')

     def init_parallel_groups(self):
         """Initializes the parallel groups.
...
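
Context for _set_parallel_size_from_config above: each parallel size may be given as a plain int or as a dict carrying a 'size' key; any other type raises the NotImplementedError shown. A hypothetical config illustrating both accepted forms:

    # Hypothetical parallel config (keys and mode value are assumptions).
    parallel = dict(
        data=4,                            # plain int form
        pipeline=2,                        # plain int form
        tensor=dict(size=4, mode='2d'),    # dict form with a 'size' key
    )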
@@ -386,9 +376,11 @@ class ParallelContext:
         # get the tensor parallel mode and check
         tensor_parallel_mode = None
-        if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+        if parallel_config is not None and 'tensor' in \
+                parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
-        assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+        assert tensor_parallel_mode in ALLOWED_MODES, \
+            f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
         env.mode = tensor_parallel_mode

         self.check_sanity()
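
The mode check from the hunk above, in isolation; the contents of ALLOWED_MODES here are an assumption for illustration, not copied from this file:

    # Hypothetical ALLOWED_MODES; the assert mirrors the one in the diff.
    ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
    tensor_parallel_mode = '2d'
    assert tensor_parallel_mode in ALLOWED_MODES, \
        f"mode in the parallel config must be set to one of {ALLOWED_MODES}"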
...
@@ -426,12 +418,10 @@ class ParallelContext:
         for initializer_cfg in pg_init:
             cfg = initializer_cfg.copy()
             initializer_type = cfg.pop('type')
-            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
+            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
                                                                               self.data_parallel_size,
                                                                               self.pipeline_parallel_size,
-                                                                              self.tensor_parallel_size,
-                                                                              **cfg)
+                                                                              self.tensor_parallel_size, **cfg)
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
                 for args in parallel_setting:
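
Context: DIST_GROUP_INITIALIZER acts as a registry that maps the config 'type' string to an initializer class, which is then instantiated with the remaining config. A minimal sketch of that dispatch pattern (hypothetical registry, not the ColossalAI implementation):

    # Hypothetical registry dispatch, illustrative only.
    REGISTRY = {}

    def register(name):
        def deco(cls):
            REGISTRY[name] = cls
            return cls
        return deco

    @register('Initializer_Data')
    class DataGroupInitializer:
        def __init__(self, rank, world_size, **cfg):
            self.rank, self.world_size, self.cfg = rank, world_size, cfg

    cfg = {'type': 'Initializer_Data'}                  # like one pg_init entry
    initializer = REGISTRY[cfg.pop('type')](rank=0, world_size=8, **cfg)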
...
@@ -509,8 +499,7 @@ class ParallelContext:
         seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])

         if self._verbose:
-            self._logger.info(f"initialized seed on rank {global_rank}, "
+            self._logger.info(f"initialized seed on rank {global_rank}, "
                               f"numpy: {seed}, python random: {seed}, {seed_str},"
                               f"the default parallel seed is {ParallelMode.DATA}.")
         else:
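
Context: the log line above reports seeds that were just applied to numpy, python's random module, and the per-mode parallel seeds. A minimal sketch of the non-CUDA part (hypothetical, not the ColossalAI implementation):

    # Seed python and numpy with one value, as the log message describes.
    import random
    import numpy as np

    seed = 1024
    random.seed(seed)
    np.random.seed(seed)
    print(f"numpy: {seed}, python random: {seed}")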
...