Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a77d73f2
Commit
a77d73f2
authored
Mar 10, 2022
by
ziyu huang
Committed by
Frank Lee
Mar 11, 2022
Browse files
fix format parallel_context.py (#359)
Co-authored-by:
huangziyu
<
202476410arsmart@gmail.com
>
parent
c695369a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
34 deletions
+23
-34
colossalai/context/parallel_context.py
colossalai/context/parallel_context.py
+23
-34
No files found.
colossalai/context/parallel_context.py
View file @
a77d73f2
#!/usr/bin/env python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# -*- encoding: utf-8 -*-
import
os
import
random
import
random
from
typing
import
Union
from
typing
import
Union
...
@@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
...
@@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
class
ParallelContext
:
class
ParallelContext
:
"""This class provides interface functions for users to get the parallel context,
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
such as the global rank, the local rank, the world size, etc. of each device.
"""
"""
...
@@ -218,7 +217,8 @@ class ParallelContext:
...
@@ -218,7 +217,8 @@ class ParallelContext:
def
is_pipeline_last_stage
(
self
,
ignore_virtual
=
False
):
def
is_pipeline_last_stage
(
self
,
ignore_virtual
=
False
):
if
not
ignore_virtual
:
if
not
ignore_virtual
:
if
self
.
virtual_pipeline_parallel_size
is
not
None
and
self
.
virtual_pipeline_parallel_rank
!=
self
.
virtual_pipeline_parallel_size
-
1
:
if
self
.
virtual_pipeline_parallel_size
\
is
not
None
and
self
.
virtual_pipeline_parallel_rank
!=
self
.
virtual_pipeline_parallel_size
-
1
:
return
False
return
False
return
self
.
is_last_rank
(
ParallelMode
.
PIPELINE
)
return
self
.
is_last_rank
(
ParallelMode
.
PIPELINE
)
...
@@ -300,13 +300,7 @@ class ParallelContext:
...
@@ -300,13 +300,7 @@ class ParallelContext:
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_ranks_in_group
[
parallel_mode
]
=
ranks
self
.
_ranks_in_group
[
parallel_mode
]
=
ranks
def init_global_dist(self,
                     rank: int,
                     world_size: int,
                     backend: str,
                     host: str,
                     port: int):
    """Initializes the global distributed environment

    :param rank: rank for the default process group
    :type rank: int
    :param world_size: world size of the default process group
    :type world_size: int
    :param backend: backend for the default process group
    :type backend: str
    :param host: the master address for distributed training
    :type host: str
    :param port: the master port for distributed training
    :type port: int
    """
    # initialize the default process group
    init_method = f'tcp://{host}:{port}'
    dist.init_process_group(rank=rank,
                            backend=backend,
                            world_size=world_size,
                            init_method=init_method)
    # None will give the default global process group for pytorch dist operations
    ranks = list(range(world_size))
    self._register_dist(rank, world_size, None, ranks, ParallelMode.GLOBAL)
    self.add_global_rank(ParallelMode.GLOBAL, rank)
def
_register_dist
(
self
,
local_rank
,
world_size
,
def
_register_dist
(
self
,
local_rank
,
world_size
,
process_group
,
ranks_in_group
,
mode
):
process_group
,
ranks_in_group
,
mode
):
self
.
add_local_rank
(
mode
,
local_rank
)
self
.
add_local_rank
(
mode
,
local_rank
)
self
.
add_world_size
(
mode
,
world_size
)
self
.
add_world_size
(
mode
,
world_size
)
self
.
add_group
(
mode
,
process_group
)
self
.
add_group
(
mode
,
process_group
)
...
@@ -349,7 +338,9 @@ class ParallelContext:
...
@@ -349,7 +338,9 @@ class ParallelContext:
tps
=
self
.
tensor_parallel_size
tps
=
self
.
tensor_parallel_size
ws
=
self
.
world_size
ws
=
self
.
world_size
assert
ws
==
dps
*
pps
*
\
assert
ws
==
dps
*
pps
*
\
tps
,
f
"Expected the world size
{
ws
}
to be equal to data parallel size (
{
dps
}
) * pipeline parallel size (
{
pps
}
) * tensor parallel size (
{
tps
}
)"
tps
,
f
"Expected the world size
{
ws
}
to be equal to data"
\
f
" parallel size (
{
dps
}
) * pipeline parallel size "
\
f
"(
{
pps
}
) * tensor parallel size (
{
tps
}
)"
def
_set_parallel_size_from_config
(
self
,
config
:
dict
,
key
:
str
,
attr_name
:
str
):
def
_set_parallel_size_from_config
(
self
,
config
:
dict
,
key
:
str
,
attr_name
:
str
):
if
key
in
config
:
if
key
in
config
:
...
@@ -360,8 +351,7 @@ class ParallelContext:
...
@@ -360,8 +351,7 @@ class ParallelContext:
setattr
(
self
,
attr_name
,
ele
[
'size'
])
setattr
(
self
,
attr_name
,
ele
[
'size'
])
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"Parallel configuration does not support this kind of argument, please use int or dict"
f
'
{
"Parallel configuration does not support this kind of argument, please use int or dict"
}
'
)
)
def
init_parallel_groups
(
self
):
def
init_parallel_groups
(
self
):
"""Initializes the parallel groups.
"""Initializes the parallel groups.
...
@@ -386,11 +376,13 @@ class ParallelContext:
...
@@ -386,11 +376,13 @@ class ParallelContext:
# get the tensor parallel mode and check
# get the tensor parallel mode and check
tensor_parallel_mode
=
None
tensor_parallel_mode
=
None
if
parallel_config
is
not
None
and
'tensor'
in
parallel_config
and
'mode'
in
parallel_config
[
'tensor'
]:
if
parallel_config
is
not
None
and
'tensor'
in
\
parallel_config
and
'mode'
in
parallel_config
[
'tensor'
]:
tensor_parallel_mode
=
parallel_config
[
'tensor'
][
'mode'
]
tensor_parallel_mode
=
parallel_config
[
'tensor'
][
'mode'
]
assert
tensor_parallel_mode
in
ALLOWED_MODES
,
f
"mode in the parallel config must be set to one of
{
ALLOWED_MODES
}
"
assert
tensor_parallel_mode
in
ALLOWED_MODES
,
\
f
"mode in the parallel config must be set to one of
{
ALLOWED_MODES
}
"
env
.
mode
=
tensor_parallel_mode
env
.
mode
=
tensor_parallel_mode
self
.
check_sanity
()
self
.
check_sanity
()
pg_init
=
[]
pg_init
=
[]
...
@@ -426,12 +418,10 @@ class ParallelContext:
...
@@ -426,12 +418,10 @@ class ParallelContext:
for
initializer_cfg
in
pg_init
:
for
initializer_cfg
in
pg_init
:
cfg
=
initializer_cfg
.
copy
()
cfg
=
initializer_cfg
.
copy
()
initializer_type
=
cfg
.
pop
(
'type'
)
initializer_type
=
cfg
.
pop
(
'type'
)
initializer
=
DIST_GROUP_INITIALIZER
.
get_module
(
initializer_type
)(
initializer
=
DIST_GROUP_INITIALIZER
.
get_module
(
initializer_type
)(
rank
,
world_size
,
self
.
config
,
rank
,
world_size
,
self
.
config
,
self
.
data_parallel_size
,
self
.
data_parallel_size
,
self
.
pipeline_parallel_size
,
self
.
pipeline_parallel_size
,
self
.
tensor_parallel_size
,
**
cfg
)
self
.
tensor_parallel_size
,
**
cfg
)
parallel_setting
=
initializer
.
init_dist_group
()
parallel_setting
=
initializer
.
init_dist_group
()
if
isinstance
(
parallel_setting
,
list
):
if
isinstance
(
parallel_setting
,
list
):
for
args
in
parallel_setting
:
for
args
in
parallel_setting
:
...
@@ -509,10 +499,9 @@ class ParallelContext:
...
@@ -509,10 +499,9 @@ class ParallelContext:
seed_str
=
', '
.
join
([
f
'
{
k
}
:
{
v
}
'
for
k
,
v
in
seeds
.
items
()])
seed_str
=
', '
.
join
([
f
'
{
k
}
:
{
v
}
'
for
k
,
v
in
seeds
.
items
()])
if
self
.
_verbose
:
if
self
.
_verbose
:
self
.
_logger
.
info
(
self
.
_logger
.
info
(
f
"initialized seed on rank
{
global_rank
}
, "
f
"initialized seed on rank
{
global_rank
}
, "
f
"numpy:
{
seed
}
, python random:
{
seed
}
,
{
seed_str
}
,"
f
"numpy:
{
seed
}
, python random:
{
seed
}
,
{
seed_str
}
,"
f
"the default parallel seed is
{
ParallelMode
.
DATA
}
."
)
f
"the default parallel seed is
{
ParallelMode
.
DATA
}
."
)
else
:
else
:
if
self
.
_verbose
:
if
self
.
_verbose
:
self
.
_logger
.
info
(
self
.
_logger
.
info
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment