OpenDAS / ColossalAI · Commits

Commit 3601b2ba (Unverified)
Authored Mar 25, 2022 by Frank Lee; committed by GitHub on Mar 25, 2022

[test] fixed rerun_on_exception and adapted test cases (#487)

Parent: 4d322b79
Changes: 31. Showing 20 changed files with 121 additions and 135 deletions (+121, -135).
Changed files shown on this page:

  colossalai/testing/utils.py                                                        +19   -2
  tests/test_amp/test_naive_fp16.py                                                   +5   -3
  tests/test_comm/test_comm.py                                                        +2   -0
  tests/test_context/test_hybrid_parallel.py                                          +2   -0
  tests/test_data/test_data_parallel_sampler.py                                      +19  -28
  tests/test_data/test_deterministic_dataloader.py                                   +20  -31
  tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py    +3   -1
  tests/test_engine/test_engine.py                                                     +2   -2
  tests/test_layers/test_1d/test_1d.py                                                 +4  -16
  tests/test_layers/test_2d/test_2d.py                                                 +3   -2
  tests/test_layers/test_2p5d/test_2p5d.py                                             +7  -13
  tests/test_layers/test_3d/test_3d.py                                                 +2   -1
  tests/test_layers/test_sequence/test_sequence.py                                    +18  -27
  tests/test_moe/test_grad_handler.py                                                  +2   -0
  tests/test_moe/test_kernel.py                                                        +2   -0
  tests/test_moe/test_moe_group.py                                                     +2   -1
  tests/test_trainer/test_pipeline/test_p2p.py                                         +2   -0
  tests/test_trainer/test_pipeline/test_partition.py                                   +3   -7
  tests/test_trainer/test_pipeline/test_pipeline_schedule.py                           +2   -0
  tests/test_trainer/test_trainer_with_non_pipe_schedule.py                            +2   -1
colossalai/testing/utils.py

 import re
 from typing import Callable, List, Any
 from functools import partial
+from inspect import signature

 def parameterize(argument: str, values: List[Any]) -> Callable:
 ...
@@ -105,6 +106,12 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
     If max_try is None, it will rerun foreven if exception keeps occurings
     """

+    def _match_lines(lines, pattern):
+        for line in lines:
+            if re.match(pattern, line):
+                return True
+        return False
+
     def _wrapper(func):
         def _run_until_success(*args, **kwargs):
 ...
@@ -115,15 +122,25 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
             while max_try is None or try_count < max_try:
                 try:
                     try_count += 1
-                    func(*args, **kwargs)
+                    ret = func(*args, **kwargs)
+                    return ret
                 except exception_type as e:
-                    if pattern is None or re.match(pattern, str(e)):
+                    error_lines = str(e).split('\n')
+                    if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)):
+                        print('Exception is caught, retrying...')
                         # when pattern is not specified, we always skip the exception
                         # when pattern is specified, we only skip when pattern is matched
                         continue
                     else:
+                        print('Maximum number of attempts is reached or pattern is not matched, no more retrying...')
                         raise e

+        # Override signature
+        # otherwise pytest.mark.parameterize will raise the following error:
+        # function does not use argumetn xxx
+        sig = signature(func)
+        _run_until_success.__signature__ = sig
         return _run_until_success

     return _wrapper
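For reference, a minimal sketch of how the fixed decorator is applied in the test files below. This is an illustration, not part of the commit: the names run_dist and test_dist and the empty config are placeholders. Each retry calls the decorated test function again, so free_port() is re-evaluated and a fresh port is picked, and the rerun only triggers when the spawned workers die with an "Address already in use" error. The signature override added above is what allows stacking this decorator with pytest.mark.parametrize without pytest complaining about unused parametrized arguments.

# Minimal usage sketch (placeholder names, not part of this commit)
from functools import partial

import pytest
import torch.multiprocessing as mp

import colossalai
from colossalai.testing import rerun_on_exception
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # every spawned worker joins the same process group on the chosen port
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # ... actual assertions would go here ...


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_dist():
    world_size = 4
    # free_port() is re-evaluated on every retry, so a rerun gets a fresh port
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)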
tests/test_amp/test_naive_fp16.py

 import torch
+import torch.multiprocessing as mp
 import colossalai
-from colossalai.testing import assert_close_loose
-import torch.multiprocessing as mp
+from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import assert_close_loose, rerun_on_exception
 from colossalai.utils import free_port
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 ...
@@ -83,6 +84,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_naive_amp():
     world_size = 1
     run_func = partial(run_dist, world_size=world_size, port=free_port())
 ...
tests/test_comm/test_comm.py

 ...
@@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_on_exception

 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
 ...
@@ -63,6 +64,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_comm():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())
 ...
tests/test_context/test_hybrid_parallel.py

 ...
@@ -13,6 +13,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.context import reset_seeds
 from colossalai.global_variables import tensor_parallel_env as tp_env
+from colossalai.testing import rerun_on_exception

 CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
 ...
@@ -140,6 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
 @pytest.mark.cpu
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_context():
     """
     As no computation or communication is done, we can run this test on CPU.
 ...
tests/test_data/test_data_parallel_sampler.py

 ...
@@ -12,29 +12,26 @@ import torch.multiprocessing as mp
 from torch.utils.data import DataLoader

 import colossalai
-from colossalai.builder import build_dataset, build_data_sampler, build_transform
+from colossalai.builder import build_dataset, build_transform
 from torchvision import transforms
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
-from colossalai.utils import get_dataloader
+from colossalai.utils import get_dataloader, free_port
+from colossalai.testing import rerun_on_exception

 CONFIG = Config(
     dict(
         train_data=dict(
             dataset=dict(
                 type='CIFAR10',
                 root=Path(os.environ['DATA']),
                 train=True,
                 download=True,
             ),
             dataloader=dict(batch_size=8,),
             transform_pipeline=[dict(type='ToTensor'),
                                 dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]),
         parallel=dict(
             pipeline=dict(size=1),
             tensor=dict(size=1, mode=None),
 ...
@@ -43,15 +40,8 @@ CONFIG = Config(
 ))

-def run_data_sampler(rank, world_size):
-    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port='29903', host='localhost')
+def run_data_sampler(rank, world_size, port):
+    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
     colossalai.launch(**dist_args)
     print('finished initialization')
 ...
@@ -71,15 +61,16 @@ def run_data_sampler(rank, world_size):
     dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))
     if gpc.get_local_rank(ParallelMode.DATA) != 0:
         assert not torch.equal(img,
                                img_to_compare), 'Same image was distributed across ranks but expected it to be different'
     torch.cuda.empty_cache()

 @pytest.mark.cpu
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_data_sampler():
     world_size = 4
-    test_func = partial(run_data_sampler, world_size=world_size)
+    test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
     mp.spawn(test_func, nprocs=world_size)
 ...
tests/test_data/test_deterministic_dataloader.py

 ...
@@ -16,45 +16,33 @@ import colossalai
 from colossalai.builder import build_dataset, build_transform
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
+from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception

 CONFIG = Config(
     dict(
         train_data=dict(
             dataset=dict(
                 type='CIFAR10',
                 root=Path(os.environ['DATA']),
                 train=True,
                 download=True,
             ),
             dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
             transform_pipeline=[
                 dict(type='ToTensor'),
                 dict(type='RandomCrop', size=32),
                 dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
             ]),
         parallel=dict(
             pipeline=dict(size=1),
             tensor=dict(size=1, mode=None),
         ),
         seed=1024,
     ))

-def run_data_sampler(rank, world_size):
-    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port='29904', host='localhost')
+def run_data_sampler(rank, world_size, port):
+    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
     colossalai.launch(**dist_args)

     dataset_cfg = gpc.config.train_data.dataset
 ...
@@ -91,9 +79,10 @@ def run_data_sampler(rank, world_size):
 @pytest.mark.cpu
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_data_sampler():
     world_size = 4
-    test_func = partial(run_data_sampler, world_size=world_size)
+    test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
     mp.spawn(test_func, nprocs=world_size)
 ...
tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py

 ...
@@ -13,8 +13,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn import LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, free_port, get_dataloader
+from colossalai.utils import free_port, get_dataloader
 from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
+from colossalai.testing import rerun_on_exception
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 ...
@@ -79,6 +80,7 @@ def run_trainer(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_hybrid_parallel():
     world_size = 8
     run_func = partial(run_trainer, world_size=world_size, port=free_port())
 ...
tests/test_engine/test_engine.py

 ...
@@ -4,11 +4,10 @@ import colossalai
 import pytest
 import torch.multiprocessing as mp
 from colossalai.amp import AMP_TYPE
-from colossalai.context import Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception

 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
               fp16=dict(mode=None),
 ...
@@ -57,6 +56,7 @@ def run_engine(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_engine():
     world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
 ...
tests/test_layers/test_1d/test_1d.py

 ...
@@ -10,28 +10,15 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers
 from colossalai.initialize import launch
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_1d.check_layer_1d import *

 CONFIG = dict(
     parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),
 )

 def check_layer(rank, world_size, port):
     disable_existing_loggers()
     launch(config=CONFIG,
            rank=rank,
            world_size=world_size,
            host='localhost',
            port=port,
            backend='nccl')

     check_linear_col()
     check_linear_row()
 ...
@@ -48,6 +35,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_1d():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())
 ...
tests/test_layers/test_2d/test_2d.py

 ...
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
 ...
@@ -18,7 +18,7 @@ from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check
                                       check_vocab_parallel_loss)
 from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB

 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),)

 def check_operations():
 ...
@@ -55,6 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_2d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
 ...
tests/test_layers/test_2p5d/test_2p5d.py

 ...
@@ -7,16 +7,14 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_2p5d.check_layer_2p5d import *
 from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB

 CONFIG = dict(
     parallel=dict(
         pipeline=dict(size=1),
         tensor=dict(size=4, mode='2.5d', depth=1),
     ),
 )

 def check_operations():
 ...
@@ -41,12 +39,7 @@ def check_layer():
 def check_layer_and_operation(rank, world_size, port):
     disable_existing_loggers()
     launch(config=CONFIG,
            rank=rank,
            world_size=world_size,
            host='localhost',
            port=port,
            backend='nccl')
     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.allow_tf32 = False
 ...
@@ -58,6 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_2p5d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
 ...
tests/test_layers/test_3d/test_3d.py

 ...
@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
 ...
@@ -51,6 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_3d():
     world_size = 8
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
 ...
tests/test_layers/test_sequence/test_sequence.py

 ...
@@ -7,14 +7,10 @@ import pytest
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+from colossalai.testing import rerun_on_exception
 from functools import partial

 CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))

 def check_ring_qk(rank, world_size):
 ...
@@ -26,14 +22,14 @@ def check_ring_qk(rank, world_size):
     sub_seq_length = seq_length // world_size

     # create master tensors
     q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

     # create distributed tensors
     sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
     sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()

     # set autograd attributes
     q.requires_grad = True
 ...
@@ -53,7 +49,7 @@ def check_ring_qk(rank, world_size):
     sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)

     # check master and distributed attetion scores
     sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)

     # run master backward
 ...
@@ -61,11 +57,11 @@ def check_ring_qk(rank, world_size):
     a.mean().backward()

     # run distributed backward
     partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_a, partial_master_a_grad)

     # check master and distributed grads
     partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
         'attention score cannot match'
 ...
@@ -79,14 +75,14 @@ def check_ring_av(rank, world_size):
     sub_seq_length = seq_length // world_size

     # create master tensors
     a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
     v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

     # create distributed tensors
     sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
     sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()

     # set autograd attributes
     a.requires_grad = True
 ...
@@ -108,7 +104,7 @@ def check_ring_av(rank, world_size):
     # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')

     # check master and distributed output
     sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)

     # # run master backward
 ...
@@ -116,23 +112,17 @@ def check_ring_av(rank, world_size):
     out.mean().backward()

     # # run distributed backward
     partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_out, partial_master_out_grad)

     # # check master and distributed grads
     partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
         'attention output cannot match'

 def run_test(rank, world_size):
     colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500)

     # check_ring_qk(rank, world_size)
     check_ring_av(rank, world_size)
 ...
@@ -142,6 +132,7 @@ def run_test(rank, world_size):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)
 ...
tests/test_moe/test_grad_handler.py

 ...
@@ -11,6 +11,7 @@ from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.testing import assert_equal_in_group
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 4
 DIM = 16
 ...
@@ -62,6 +63,7 @@ def run_test(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_grad_handler():
     world_size = 4
     run_func = partial(run_test, world_size=world_size, port=free_port())
 ...
tests/test_moe/test_kernel.py

 ...
@@ -9,6 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 16
 NUM_EXPERTS = 4
 ...
@@ -86,6 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 @pytest.mark.parametrize("hidden_size", [32, 144])
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 @pytest.mark.parametrize("router", [Top1Router, Top2Router])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_kernel(rs, hidden_size, data_type, router):
     world_size = 4
     run_func = partial(run_routing,
 ...
tests/test_moe/test_moe_group.py

 ...
@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.testing import assert_equal_in_group
+from colossalai.testing import assert_equal_in_group, rerun_on_exception

 D_MODEL = 4
 D_FF = 8
 ...
@@ -60,6 +60,7 @@ def run_test(rank, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_initialization():
     world_size = 4
     run_func = partial(run_test, port=free_port())
 ...
tests/test_trainer/test_pipeline/test_p2p.py

 ...
@@ -15,6 +15,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 4
 SEQ_LENGTH = 2
 ...
@@ -92,6 +93,7 @@ def run_check(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_p2p():
     world_size = 4
     run_func = partial(run_check, world_size=world_size, port=free_port())
 ...
tests/test_trainer/test_pipeline/test_partition.py

 ...
@@ -10,19 +10,14 @@ from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from functools import partial
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception

 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')

 def run_partition(rank, world_size, port):
     launch(config=CONFIG_PATH,
            rank=rank,
            world_size=world_size,
            host='localhost',
            port=port,
            backend='nccl')
     logger = get_dist_logger()
     logger.info('finished initialization')
 ...
@@ -37,6 +32,7 @@ def run_partition(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_partition():
     world_size = 4
     run_func = partial(run_partition, world_size=world_size, port=free_port())
 ...
tests/test_trainer/test_pipeline/test_pipeline_schedule.py

 ...
@@ -14,6 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine.schedule import PipelineSchedule
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_dataloader, print_rank_0
+from colossalai.testing import rerun_on_exception
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 ...
@@ -67,6 +68,7 @@ def run_schedule(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_pipeline_schedule():
     world_size = 4
     run_func = partial(run_schedule, world_size=world_size, port=free_port())
 ...
tests/test_trainer/test_trainer_with_non_pipe_schedule.py

 ...
@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer, free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception

 BATCH_SIZE = 4
 IMG_SIZE = 32
 ...
@@ -51,6 +51,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_trainer_no_pipeline():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
 ...