Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
943982d2
Unverified
Commit
943982d2
authored
Apr 22, 2022
by
Frank Lee
Committed by
GitHub
Apr 22, 2022
Browse files
[unittest] refactored unit tests for change in dependency (#838)
parent
f271f347
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
58 deletions
+25
-58
tests/test_data/test_cifar10_dataset.py
tests/test_data/test_cifar10_dataset.py
+4
-17
tests/test_data/test_data_parallel_sampler.py
tests/test_data/test_data_parallel_sampler.py
+15
-25
tests/test_data/test_deterministic_dataloader.py
tests/test_data/test_deterministic_dataloader.py
+6
-16
No files found.
tests/test_data/test_cifar10_dataset.py
View file @
943982d2
...
@@ -5,34 +5,21 @@ import os
...
@@ -5,34 +5,21 @@ import os
from
pathlib
import
Path
from
pathlib
import
Path
import
pytest
import
pytest
from
torchvision
import
transforms
from
torchvision
import
transforms
,
datasets
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
colossalai.builder
import
build_dataset
,
build_transform
from
colossalai.context
import
Config
from
torchvision.transforms
import
ToTensor
TRAIN_DATA
=
dict
(
dataset
=
dict
(
type
=
'CIFAR10'
,
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
),
dataloader
=
dict
(
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
2
))
@
pytest
.
mark
.
cpu
@
pytest
.
mark
.
cpu
def
test_cifar10_dataset
():
def
test_cifar10_dataset
():
config
=
Config
(
TRAIN_DATA
)
dataset_cfg
=
config
.
dataset
dataloader_cfg
=
config
.
dataloader
transform_cfg
=
config
.
transform_pipeline
# build transform
# build transform
transform_pipeline
=
[
ToTensor
()]
transform_pipeline
=
[
transforms
.
ToTensor
()]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
dataset_cfg
[
'transform'
]
=
transform_pipeline
# build dataset
# build dataset
dataset
=
build_
dataset
(
dataset_cfg
)
dataset
=
dataset
s
.
CIFAR10
(
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
transform
=
transform_pipeline
)
# build dataloader
# build dataloader
dataloader
=
DataLoader
(
dataset
=
dataset
,
**
dataloader_cfg
)
dataloader
=
DataLoader
(
dataset
=
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
2
)
data_iter
=
iter
(
dataloader
)
data_iter
=
iter
(
dataloader
)
img
,
label
=
data_iter
.
next
()
img
,
label
=
data_iter
.
next
()
...
...
tests/test_data/test_data_parallel_sampler.py
View file @
943982d2
...
@@ -9,34 +9,21 @@ import pytest
...
@@ -9,34 +9,21 @@ import pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
torch.utils.data
import
DataLoader
import
colossalai
import
colossalai
from
colossalai.builder
import
build_dataset
from
torchvision
import
transforms
,
datasets
from
torchvision
import
transforms
from
colossalai.context
import
ParallelMode
,
Config
from
colossalai.context
import
ParallelMode
,
Config
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
get_dataloader
,
free_port
from
colossalai.utils
import
get_dataloader
,
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
torchvision.transforms
import
ToTensor
CONFIG
=
Config
(
CONFIG
=
Config
(
dict
(
dict
(
train_data
=
dict
(
dataset
=
dict
(
type
=
'CIFAR10'
,
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
),
dataloader
=
dict
(
batch_size
=
8
,),
),
parallel
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
1
,
mode
=
None
),
tensor
=
dict
(
size
=
1
,
mode
=
None
),
),
),
seed
=
1024
,
seed
=
1024
,
))
))
def
run_data_sampler
(
rank
,
world_size
,
port
):
def
run_data_sampler
(
rank
,
world_size
,
port
):
...
@@ -44,11 +31,14 @@ def run_data_sampler(rank, world_size, port):
...
@@ -44,11 +31,14 @@ def run_data_sampler(rank, world_size, port):
colossalai
.
launch
(
**
dist_args
)
colossalai
.
launch
(
**
dist_args
)
print
(
'finished initialization'
)
print
(
'finished initialization'
)
transform_pipeline
=
[
ToTensor
()]
# build dataset
transform_pipeline
=
[
transforms
.
ToTensor
()]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
gpc
.
config
.
train_data
.
dataset
[
'transform'
]
=
transform_pipeline
dataset
=
datasets
.
CIFAR10
(
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
transform
=
transform_pipeline
)
dataset
=
build_dataset
(
gpc
.
config
.
train_data
.
dataset
)
dataloader
=
get_dataloader
(
dataset
,
**
gpc
.
config
.
train_data
.
dataloader
)
# build dataloader
dataloader
=
get_dataloader
(
dataset
,
batch_size
=
8
,
add_sampler
=
True
)
data_iter
=
iter
(
dataloader
)
data_iter
=
iter
(
dataloader
)
img
,
label
=
data_iter
.
next
()
img
,
label
=
data_iter
.
next
()
img
=
img
[
0
]
img
=
img
[
0
]
...
...
tests/test_data/test_deterministic_dataloader.py
View file @
943982d2
...
@@ -9,14 +9,12 @@ import pytest
...
@@ -9,14 +9,12 @@ import pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
torchvision
import
transforms
from
torchvision
import
transforms
,
datasets
from
torch.utils.data
import
DataLoader
import
colossalai
import
colossalai
from
colossalai.builder
import
build_dataset
from
colossalai.context
import
ParallelMode
,
Config
from
colossalai.context
import
ParallelMode
,
Config
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
colossalai.utils
import
get_dataloader
,
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
torchvision
import
transforms
from
torchvision
import
transforms
...
@@ -43,20 +41,13 @@ def run_data_sampler(rank, world_size, port):
...
@@ -43,20 +41,13 @@ def run_data_sampler(rank, world_size, port):
dist_args
=
dict
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
backend
=
'gloo'
,
port
=
port
,
host
=
'localhost'
)
dist_args
=
dict
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
backend
=
'gloo'
,
port
=
port
,
host
=
'localhost'
)
colossalai
.
launch
(
**
dist_args
)
colossalai
.
launch
(
**
dist_args
)
dataset_cfg
=
gpc
.
config
.
train_data
.
dataset
dataloader_cfg
=
gpc
.
config
.
train_data
.
dataloader
transform_cfg
=
gpc
.
config
.
train_data
.
transform_pipeline
# build transform
transform_pipeline
=
[
transforms
.
ToTensor
(),
transforms
.
RandomCrop
(
size
=
32
)]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
dataset_cfg
[
'transform'
]
=
transform_pipeline
# build dataset
# build dataset
dataset
=
build_dataset
(
dataset_cfg
)
transform_pipeline
=
[
transforms
.
ToTensor
(),
transforms
.
RandomCrop
(
size
=
32
,
padding
=
4
)]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
dataset
=
datasets
.
CIFAR10
(
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
transform
=
transform_pipeline
)
# build dataloader
# build dataloader
dataloader
=
D
ata
L
oader
(
dataset
=
dataset
,
**
dataloader_cfg
)
dataloader
=
get_d
ata
l
oader
(
dataset
,
batch_size
=
8
,
add_sampler
=
False
)
data_iter
=
iter
(
dataloader
)
data_iter
=
iter
(
dataloader
)
img
,
label
=
data_iter
.
next
()
img
,
label
=
data_iter
.
next
()
...
@@ -76,7 +67,6 @@ def run_data_sampler(rank, world_size, port):
...
@@ -76,7 +67,6 @@ def run_data_sampler(rank, world_size, port):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
cpu
@
pytest
.
mark
.
cpu
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_data_sampler
():
def
test_data_sampler
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment