Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b72b8445
Unverified
Commit
b72b8445
authored
Mar 17, 2022
by
Frank Lee
Committed by
GitHub
Mar 17, 2022
Browse files
optimized context test time consumption (#446)
parent
496cbb07
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
169 additions
and
357 deletions
+169
-357
colossalai/context/parallel_context.py
colossalai/context/parallel_context.py
+1
-0
colossalai/testing/comparison.py
colossalai/testing/comparison.py
+1
-1
colossalai/utils/common.py
colossalai/utils/common.py
+1
-0
tests/test_amp/test_naive_fp16.py
tests/test_amp/test_naive_fp16.py
+4
-3
tests/test_context/test_2d_init.py
tests/test_context/test_2d_init.py
+0
-105
tests/test_context/test_2p5d_init.py
tests/test_context/test_2p5d_init.py
+0
-128
tests/test_context/test_3d_init.py
tests/test_context/test_3d_init.py
+0
-120
tests/test_context/test_hybrid_parallel.py
tests/test_context/test_hybrid_parallel.py
+162
-0
No files found.
colossalai/context/parallel_context.py
View file @
b72b8445
...
...
@@ -449,6 +449,7 @@ class ParallelContext:
dist
.
destroy_process_group
(
group
)
# destroy global process group
dist
.
destroy_process_group
()
self
.
_groups
.
clear
()
def
set_device
(
self
,
device_ordinal
:
int
=
None
):
"""Sets distributed processes to be bound to devices.
...
...
colossalai/testing/comparison.py
View file @
b72b8445
...
...
@@ -13,7 +13,7 @@ def assert_not_equal(a: Tensor, b: Tensor):
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
    """Assert that two tensors are element-wise close via ``torch.allclose``.

    Raises AssertionError (showing both operands) when any element differs
    by more than ``atol + rtol * |b|``.
    """
    is_close = torch.allclose(a, b, rtol=rtol, atol=atol)
    assert is_close, f'expected a and b to be close but they are not, {a} vs {b}'
def
assert_close_loose
(
a
:
Tensor
,
b
:
Tensor
,
rtol
:
float
=
1e-
2
,
atol
:
float
=
1e-3
):
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
    """Assert element-wise closeness under loose (1e-3) tolerances.

    Thin wrapper over ``assert_close`` with relaxed defaults, suited to
    low-precision (e.g. fp16) comparisons.
    """
    assert_close(a, b, rtol=rtol, atol=atol)
def
assert_equal_in_group
(
tensor
:
Tensor
,
process_group
:
ProcessGroup
=
None
):
...
...
colossalai/utils/common.py
View file @
b72b8445
...
...
@@ -46,6 +46,7 @@ def free_port():
while
True
:
try
:
sock
=
socket
.
socket
()
sock
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
port
=
random
.
randint
(
20000
,
65000
)
sock
.
bind
((
'localhost'
,
port
))
sock
.
close
()
...
...
tests/test_amp/test_naive_fp16.py
View file @
b72b8445
...
...
@@ -5,6 +5,7 @@ import pytest
import
torch.multiprocessing
as
mp
from
colossalai.amp
import
convert_to_naive_amp
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.testing
import
assert_close_loose
from
colossalai.utils
import
free_port
from
functools
import
partial
...
...
@@ -48,7 +49,7 @@ def run_naive_amp():
# forward pass
amp_output
=
amp_model
(
data
)
torch_output
=
torch_model
(
data
)
assert
torch
.
allclose
(
amp_output
,
torch_output
,
rtol
=
1e-3
,
atol
=
1e-3
),
f
'
{
amp_output
}
vs
{
torch_output
}
'
assert
_close_loose
(
amp_output
,
torch_output
)
# backward
amp_optimizer
.
backward
(
amp_output
.
mean
())
...
...
@@ -56,7 +57,7 @@ def run_naive_amp():
# check grad
for
amp_param
,
torch_param
in
zip
(
amp_model
.
parameters
(),
torch_model
.
parameters
()):
torch
.
allcl
ose
(
amp_param
.
grad
,
torch_param
.
grad
.
half
()
,
rtol
=
1e-3
,
atol
=
1e-3
)
assert_close_lo
ose
(
amp_param
.
grad
,
torch_param
.
grad
.
half
())
# step
amp_optimizer
.
step
()
...
...
@@ -64,7 +65,7 @@ def run_naive_amp():
# check updated param
for
amp_param
,
torch_param
in
zip
(
amp_model
.
parameters
(),
torch_model
.
parameters
()):
torch
.
allcl
ose
(
amp_param
,
torch_param
.
half
()
,
rtol
=
1e-3
,
atol
=
1e-3
)
assert_close_lo
ose
(
amp_param
,
torch_param
.
half
())
def
run_dist
(
rank
,
world_size
,
port
):
...
...
tests/test_context/test_2d_init.py
deleted
100644 → 0
View file @
496cbb07
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
functools
import
partial
from
pathlib
import
Path
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai
import
launch
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
CONFIG_PATH
=
Path
(
__file__
).
parent
.
joinpath
(
'configs/parallel_2d_init.py'
).
absolute
()
def
check_data_parallel_rank
(
rank
):
if
rank
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
0
elif
rank
in
[
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
1
def
check_pipeline_parallel_rank
(
rank
):
if
rank
in
[
0
,
1
,
2
,
3
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
0
elif
rank
in
[
4
,
5
,
6
,
7
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
1
elif
rank
in
[
8
,
9
,
10
,
11
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
0
elif
rank
in
[
12
,
13
,
14
,
15
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
==
1
def
check_model_parallel_rank
(
rank
):
for
i
in
range
(
8
):
if
rank
in
[
i
,
i
+
8
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
MODEL
)
==
i
def
check_tensor_parallel_rank
(
rank
):
if
rank
in
[
0
,
4
,
8
,
12
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
==
0
elif
rank
in
[
1
,
5
,
9
,
13
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
==
1
elif
rank
in
[
2
,
6
,
10
,
14
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
==
2
elif
rank
in
[
3
,
7
,
11
,
15
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
==
3
def
check_2d_parallel_rank
(
rank
):
if
rank
in
[
0
,
4
,
8
,
12
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_COL
)
==
0
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_ROW
)
==
0
elif
rank
in
[
1
,
5
,
9
,
13
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_COL
)
==
0
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_ROW
)
==
1
elif
rank
in
[
2
,
6
,
10
,
14
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_COL
)
==
1
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_ROW
)
==
0
elif
rank
in
[
3
,
7
,
11
,
15
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_COL
)
==
1
assert
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2D_ROW
)
==
1
def
init_2d
(
rank
,
world_size
,
backend
,
port
,
host
):
dist_args
=
dict
(
config
=
CONFIG_PATH
,
rank
=
rank
,
world_size
=
world_size
,
backend
=
backend
,
port
=
port
,
host
=
host
,
verbose
=
True
)
launch
(
**
dist_args
)
check_tensor_parallel_rank
(
rank
)
check_data_parallel_rank
(
rank
)
check_2d_parallel_rank
(
rank
)
check_pipeline_parallel_rank
(
rank
)
check_model_parallel_rank
(
rank
)
gpc
.
destroy
()
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
cpu
def
test_2d_init
():
"""
As no computation or communication is done, we can run this test on CPU.
"""
world_size
=
16
test_fn
=
partial
(
init_2d
,
world_size
=
world_size
,
backend
=
'gloo'
,
port
=
free_port
(),
host
=
'localhost'
)
mp
.
spawn
(
test_fn
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_2d_init
()
tests/test_context/test_2p5d_init.py
deleted
100644 → 0
View file @
496cbb07
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
functools
import
partial
from
pathlib
import
Path
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.initialize
import
launch
from
colossalai.utils
import
free_port
CONFIG_PATH
=
Path
(
__file__
).
parent
.
joinpath
(
'configs/parallel_2p5d_init.py'
).
absolute
()
def
check_data_parallel_rank
(
rank
):
dp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
if
rank
in
list
(
range
(
16
)):
assert
dp_rank
==
0
elif
rank
in
list
(
range
(
16
,
32
)):
assert
dp_rank
==
1
def
check_pipeline_parallel_rank
(
rank
):
ppr
=
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
if
rank
in
list
(
range
(
8
)):
assert
ppr
==
0
elif
rank
in
list
(
range
(
8
,
16
)):
assert
ppr
==
1
elif
rank
in
list
(
range
(
16
,
24
)):
assert
ppr
==
0
elif
rank
in
list
(
range
(
24
,
32
)):
assert
ppr
==
1
def
check_model_parallel_rank
(
rank
):
for
i
in
range
(
16
):
if
rank
in
[
i
,
i
+
16
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
MODEL
)
==
i
def
check_tensor_parallel_rank
(
rank
):
tp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
for
i
in
range
(
8
):
ranks
=
list
(
range
(
i
,
32
,
8
))
if
rank
in
ranks
:
assert
tp_rank
==
i
,
f
'
{
rank
}
:
{
tp_rank
}
'
def
check_2p5d_parallel_rank
(
rank
):
rp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2P5D_ROW
)
cp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2P5D_COL
)
dp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2P5D_DEP
)
xp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_2P5D_XZ
)
# check for row parallel group
for
i
in
range
(
2
):
ranks
=
list
(
range
(
i
,
32
,
2
))
if
rank
in
ranks
:
assert
rp_rank
==
i
# check for col parallel group
for
i
in
range
(
2
):
ranks
=
list
(
range
(
i
*
2
,
32
,
4
))
ranks_plus_ones
=
[
val
+
1
for
val
in
ranks
]
ranks
.
extend
(
ranks_plus_ones
)
if
rank
in
ranks
:
assert
cp_rank
==
i
# check for depth parallel group
for
i
in
range
(
2
):
ranks
=
[]
for
j
in
range
(
i
*
4
,
32
,
8
):
ranks
.
extend
([
j
+
k
for
k
in
range
(
4
)])
if
rank
in
ranks
:
assert
dp_rank
==
i
# check for xz parallel group
for
i
in
range
(
2
):
ranks
=
list
(
range
(
i
*
2
,
32
,
8
))
ranks_plus_one
=
[
val
+
1
for
val
in
ranks
]
ranks
.
extend
(
ranks_plus_one
)
if
rank
in
ranks
:
assert
xp_rank
==
i
def
init_2halfd
(
rank
,
world_size
,
backend
,
port
,
host
):
dist_args
=
dict
(
config
=
CONFIG_PATH
,
rank
=
rank
,
world_size
=
world_size
,
backend
=
backend
,
port
=
port
,
host
=
host
,
verbose
=
True
)
launch
(
**
dist_args
)
check_data_parallel_rank
(
rank
)
check_pipeline_parallel_rank
(
rank
)
check_tensor_parallel_rank
(
rank
)
check_2p5d_parallel_rank
(
rank
)
check_model_parallel_rank
(
rank
)
gpc
.
destroy
()
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
cpu
def
test_2halfd_init
():
"""
As no computation or communication is done, we can run this test on CPU.
"""
world_size
=
32
test_fn
=
partial
(
init_2halfd
,
world_size
=
world_size
,
backend
=
'gloo'
,
port
=
free_port
(),
host
=
'localhost'
)
mp
.
spawn
(
test_fn
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_2halfd_init
()
tests/test_context/test_3d_init.py
deleted
100644 → 0
View file @
496cbb07
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
functools
import
partial
from
pathlib
import
Path
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.initialize
import
launch
from
colossalai.utils
import
free_port
CONFIG_PATH
=
Path
(
__file__
).
parent
.
joinpath
(
'configs/parallel_3d_init.py'
).
absolute
()
def
check_data_parallel_rank
(
rank
):
dp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
if
rank
in
list
(
range
(
16
)):
assert
dp_rank
==
0
elif
rank
in
list
(
range
(
16
,
32
)):
assert
dp_rank
==
1
def
check_pipeline_parallel_rank
(
rank
):
ppr
=
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
if
rank
in
list
(
range
(
8
)):
assert
ppr
==
0
elif
rank
in
list
(
range
(
8
,
16
)):
assert
ppr
==
1
elif
rank
in
list
(
range
(
16
,
24
)):
assert
ppr
==
0
elif
rank
in
list
(
range
(
24
,
32
)):
assert
ppr
==
1
def
check_model_parallel_rank
(
rank
):
for
i
in
range
(
16
):
if
rank
in
[
i
,
i
+
16
]:
assert
gpc
.
get_local_rank
(
ParallelMode
.
MODEL
)
==
i
def
check_tensor_parallel_rank
(
rank
):
tp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
for
i
in
range
(
8
):
ranks
=
list
(
range
(
i
,
32
,
8
))
if
rank
in
ranks
:
assert
tp_rank
==
i
def
check_3d_parallel_rank
(
rank
):
ip_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_3D_INPUT
)
wp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_3D_WEIGHT
)
op_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_3D_OUTPUT
)
# check for input parallel group
for
i
in
range
(
2
):
_ranks
=
list
(
range
(
i
*
2
,
32
,
4
))
_ranks_plus_one
=
[
val
+
1
for
val
in
_ranks
]
input_ranks
=
_ranks
+
_ranks_plus_one
if
rank
in
input_ranks
:
assert
ip_rank
==
i
# check for weight parallel group
for
i
in
range
(
2
):
ranks
=
list
(
range
(
i
,
32
,
2
))
if
rank
in
ranks
:
assert
wp_rank
==
i
# check for output parallel group
for
i
in
range
(
2
):
ranks
=
[]
for
j
in
range
(
i
*
4
,
32
,
8
):
ranks
.
extend
([
j
+
k
for
k
in
range
(
4
)])
if
rank
in
ranks
:
assert
op_rank
==
i
def
init_3d
(
rank
,
world_size
,
backend
,
port
,
host
):
dist_args
=
dict
(
config
=
CONFIG_PATH
,
rank
=
rank
,
world_size
=
world_size
,
backend
=
backend
,
port
=
port
,
host
=
host
,
verbose
=
True
)
launch
(
**
dist_args
)
check_tensor_parallel_rank
(
rank
)
check_3d_parallel_rank
(
rank
)
check_data_parallel_rank
(
rank
)
check_pipeline_parallel_rank
(
rank
)
check_model_parallel_rank
(
rank
)
gpc
.
destroy
()
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
cpu
def
test_3d_init
():
"""
As no computation or communication is done, we can run this test on CPU.
"""
world_size
=
32
test_fn
=
partial
(
init_3d
,
world_size
=
world_size
,
backend
=
'gloo'
,
port
=
free_port
(),
host
=
'localhost'
)
mp
.
spawn
(
test_fn
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_3d_init
()
tests/test_context/test_hybrid_parallel.py
0 → 100644
View file @
b72b8445
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
functools
import
partial
from
pathlib
import
Path
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai
import
launch
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
colossalai.context
import
reset_seeds
from
colossalai.global_variables
import
tensor_parallel_env
as
tp_env
CONFIG_PATH_LIST
=
list
(
Path
(
__file__
).
parent
.
glob
(
'configs/*.py'
))
def check_data_parallel_rank(rank):
    """Check that this process's DATA local rank equals the index of the
    data-parallel group its global rank falls into.

    Args:
        rank: global rank of the current process.
    """
    world = gpc.get_world_size(ParallelMode.GLOBAL)
    model_size = gpc.get_world_size(ParallelMode.MODEL)
    n_dp_groups = world // model_size
    local_dp_rank = gpc.get_local_rank(ParallelMode.DATA)

    # the DATA world size must match the number of model-parallel replicas
    assert gpc.get_world_size(ParallelMode.DATA) == n_dp_groups

    for idx in range(n_dp_groups):
        lo = idx * model_size
        hi = lo + model_size
        if lo <= rank < hi:
            assert local_dp_rank == idx
def check_pipeline_parallel_rank(rank):
    """Check that this process's PIPELINE local rank equals the pipeline
    stage index its global rank falls into.

    Args:
        rank: global rank of the current process.
    """
    model_size = gpc.get_world_size(ParallelMode.MODEL)
    tensor_size = gpc.get_world_size(ParallelMode.TENSOR)
    n_stages = model_size // tensor_size
    local_pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    for stage in range(n_stages):
        lo = stage * tensor_size
        hi = lo + tensor_size
        if lo <= rank < hi:
            assert stage == local_pp_rank
def check_model_parallel_rank(rank):
    """Check that the MODEL local rank equals the global rank modulo the
    model-parallel group size.

    Args:
        rank: global rank of the current process.
    """
    model_size = gpc.get_world_size(ParallelMode.MODEL)
    expected = rank % model_size
    assert expected == gpc.get_local_rank(ParallelMode.MODEL)
def check_tensor_parallel_rank(rank):
    """Dispatch to the rank check that matches the active tensor-parallel mode.

    Args:
        rank: global rank of the current process.

    BUGFIX: the original compared ``tp_env`` itself (not ``tp_env.mode``) to
    '2.5d' and '3d'; an object never equals a string, so those branches were
    always False and the 2.5d/3d checks were silently skipped. All three
    branches now read ``tp_env.mode`` consistently.
    """
    if tp_env.mode == '2d':
        check_2d_tensor_parallel_rank(rank)
    elif tp_env.mode == '2.5d':
        check_2p5d_tensor_parallel_rank(rank)
    elif tp_env.mode == '3d':
        check_3d_tensor_parallel_rank(rank)
def get_tp_info():
    """Gather tensor-parallel topology info for the current process.

    Returns:
        Tuple of (tensor-parallel local rank, tensor-parallel world size,
        number of tensor-parallel groups in the global world).
    """
    world = gpc.get_world_size(ParallelMode.GLOBAL)
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    return local_rank, tp_size, world // tp_size
def check_2d_tensor_parallel_rank(rank):
    """Check 2D tensor-parallel row/column local ranks for this process.

    Within its tensor-parallel group, the column rank is the local rank
    divided by the SUMMA dimension and the row rank is the remainder.

    Args:
        rank: global rank of the current process.
    """
    local_rank, tp_size, n_groups = get_tp_info()

    for gid in range(n_groups):
        lo = gid * tp_size
        if lo <= rank < lo + tp_size:
            col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
            row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
            assert col_rank == local_rank // tp_env.summa_dim
            assert row_rank == local_rank % tp_env.summa_dim
def check_2p5d_tensor_parallel_rank(rank):
    """Check 2.5D tensor-parallel row/col/depth/xz local ranks for this process.

    Args:
        rank: global rank of the current process.

    NOTE(review): the expected formulas mix ``summa_dim`` and
    ``tesseract_dim`` (e.g. xz uses ``summa_dim``); preserved as written —
    confirm against the 2.5D initializer if these ever disagree.
    """
    local_rank, tp_size, n_groups = get_tp_info()

    for gid in range(n_groups):
        lo = gid * tp_size
        if lo <= rank < lo + tp_size:
            row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
            col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
            dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
            xz_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
            assert row_rank == local_rank % tp_env.summa_dim
            assert col_rank == local_rank // tp_env.tesseract_dim
            assert dep_rank == local_rank // (tp_env.summa_dim ** 2)
            assert xz_rank == local_rank // tp_env.summa_dim
def check_3d_tensor_parallel_rank(rank):
    """Check 3D tensor-parallel input/weight/output local ranks for this process.

    Within its tensor-parallel group: input rank is the local rank modulo
    the 3D depth, weight rank is the quotient, and output rank is the
    quotient by depth squared.

    Args:
        rank: global rank of the current process.
    """
    local_rank, tp_size, n_groups = get_tp_info()

    for gid in range(n_groups):
        lo = gid * tp_size
        if lo <= rank < lo + tp_size:
            in_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
            wt_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
            out_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
            assert in_rank == local_rank % tp_env.depth_3d
            assert wt_rank == local_rank // tp_env.depth_3d
            assert out_rank == local_rank // (tp_env.depth_3d ** 2)
def init_context(config_path, rank, world_size, backend, port, host):
    """Launch a distributed context from one config, run every rank check,
    then tear the context down.

    Args:
        config_path: path to the colossalai config file to launch with.
        rank: global rank of the current process.
        world_size: total number of processes.
        backend: torch.distributed backend name (e.g. 'gloo').
        port: master port for rendezvous.
        host: master host for rendezvous.
    """
    launch(config=config_path,
           rank=rank,
           world_size=world_size,
           backend=backend,
           port=port,
           host=host,
           verbose=True)

    check_tensor_parallel_rank(rank)
    check_data_parallel_rank(rank)
    check_pipeline_parallel_rank(rank)
    check_model_parallel_rank(rank)

    # release process groups and cached CUDA memory so the next config can
    # re-initialize cleanly
    gpc.destroy()
    torch.cuda.empty_cache()
def run_dist(rank, world_size, backend, port_list, host):
    """Per-process worker: initialize and check every config in turn.

    Each config gets its own pre-allocated port; seeds are reset between
    configs so runs stay independent.

    Args:
        rank: global rank of the current process.
        world_size: total number of processes.
        backend: torch.distributed backend name.
        port_list: one distinct port per entry of CONFIG_PATH_LIST.
        host: master host for rendezvous.
    """
    for cfg, cfg_port in zip(CONFIG_PATH_LIST, port_list):
        init_context(config_path=cfg,
                     rank=rank,
                     world_size=world_size,
                     backend=backend,
                     port=cfg_port,
                     host=host)
        reset_seeds()
@pytest.mark.cpu
def test_context():
    """
    As no computation or communication is done, we can run this test on CPU.
    """
    world_size = 32

    # reserve one distinct port per config; retry on duplicates since
    # free_port() may hand back the same port twice
    port_list = []
    while len(port_list) < len(CONFIG_PATH_LIST):
        candidate = free_port()
        if candidate not in port_list:
            port_list.append(candidate)

    worker = partial(run_dist,
                     world_size=world_size,
                     backend='gloo',
                     port_list=port_list,
                     host='localhost')
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    test_context()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment