Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
943982d2
Unverified
Commit
943982d2
authored
Apr 22, 2022
by
Frank Lee
Committed by
GitHub
Apr 22, 2022
Browse files
[unittest] refactored unit tests for change in dependency (#838)
parent
f271f347
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
58 deletions
+25
-58
tests/test_data/test_cifar10_dataset.py
tests/test_data/test_cifar10_dataset.py
+4
-17
tests/test_data/test_data_parallel_sampler.py
tests/test_data/test_data_parallel_sampler.py
+15
-25
tests/test_data/test_deterministic_dataloader.py
tests/test_data/test_deterministic_dataloader.py
+6
-16
No files found.
tests/test_data/test_cifar10_dataset.py
View file @
943982d2
...
...
@@ -5,34 +5,21 @@ import os
from
pathlib
import
Path
import
pytest
from
torchvision
import
transforms
from
torchvision
import
transforms
,
datasets
from
torch.utils.data
import
DataLoader
from
colossalai.builder
import
build_dataset
,
build_transform
from
colossalai.context
import
Config
from
torchvision.transforms
import
ToTensor
TRAIN_DATA
=
dict
(
dataset
=
dict
(
type
=
'CIFAR10'
,
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
),
dataloader
=
dict
(
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
2
))
@
pytest
.
mark
.
cpu
def
test_cifar10_dataset
():
config
=
Config
(
TRAIN_DATA
)
dataset_cfg
=
config
.
dataset
dataloader_cfg
=
config
.
dataloader
transform_cfg
=
config
.
transform_pipeline
# build transform
transform_pipeline
=
[
ToTensor
()]
transform_pipeline
=
[
transforms
.
ToTensor
()]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
dataset_cfg
[
'transform'
]
=
transform_pipeline
# build dataset
dataset
=
build_
dataset
(
dataset_cfg
)
dataset
=
dataset
s
.
CIFAR10
(
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
transform
=
transform_pipeline
)
# build dataloader
dataloader
=
DataLoader
(
dataset
=
dataset
,
**
dataloader_cfg
)
dataloader
=
DataLoader
(
dataset
=
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
2
)
data_iter
=
iter
(
dataloader
)
img
,
label
=
data_iter
.
next
()
...
...
tests/test_data/test_data_parallel_sampler.py
View file @
943982d2
...
...
@@ -9,34 +9,21 @@ import pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
torch.utils.data
import
DataLoader
import
colossalai
from
colossalai.builder
import
build_dataset
from
torchvision
import
transforms
from
torchvision
import
transforms
,
datasets
from
colossalai.context
import
ParallelMode
,
Config
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
get_dataloader
,
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
torchvision.transforms
import
ToTensor
CONFIG
=
Config
(
dict
(
train_data
=
dict
(
dataset
=
dict
(
type
=
'CIFAR10'
,
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
),
dataloader
=
dict
(
batch_size
=
8
,),
),
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
1
,
mode
=
None
),
),
seed
=
1024
,
))
CONFIG
=
Config
(
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
1
,
mode
=
None
),
),
seed
=
1024
,
))
def
run_data_sampler
(
rank
,
world_size
,
port
):
...
...
@@ -44,11 +31,14 @@ def run_data_sampler(rank, world_size, port):
colossalai
.
launch
(
**
dist_args
)
print
(
'finished initialization'
)
transform_pipeline
=
[
ToTensor
()]
# build dataset
transform_pipeline
=
[
transforms
.
ToTensor
()]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
gpc
.
config
.
train_data
.
dataset
[
'transform'
]
=
transform_pipeline
dataset
=
build_dataset
(
gpc
.
config
.
train_data
.
dataset
)
dataloader
=
get_dataloader
(
dataset
,
**
gpc
.
config
.
train_data
.
dataloader
)
dataset
=
datasets
.
CIFAR10
(
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
transform
=
transform_pipeline
)
# build dataloader
dataloader
=
get_dataloader
(
dataset
,
batch_size
=
8
,
add_sampler
=
True
)
data_iter
=
iter
(
dataloader
)
img
,
label
=
data_iter
.
next
()
img
=
img
[
0
]
...
...
tests/test_data/test_deterministic_dataloader.py
View file @
943982d2
...
...
@@ -9,14 +9,12 @@ import pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
torchvision
import
transforms
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
,
datasets
import
colossalai
from
colossalai.builder
import
build_dataset
from
colossalai.context
import
ParallelMode
,
Config
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
colossalai.utils
import
get_dataloader
,
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
torchvision
import
transforms
...
...
@@ -43,20 +41,13 @@ def run_data_sampler(rank, world_size, port):
dist_args
=
dict
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
backend
=
'gloo'
,
port
=
port
,
host
=
'localhost'
)
colossalai
.
launch
(
**
dist_args
)
dataset_cfg
=
gpc
.
config
.
train_data
.
dataset
dataloader_cfg
=
gpc
.
config
.
train_data
.
dataloader
transform_cfg
=
gpc
.
config
.
train_data
.
transform_pipeline
# build transform
transform_pipeline
=
[
transforms
.
ToTensor
(),
transforms
.
RandomCrop
(
size
=
32
)]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
dataset_cfg
[
'transform'
]
=
transform_pipeline
# build dataset
dataset
=
build_dataset
(
dataset_cfg
)
transform_pipeline
=
[
transforms
.
ToTensor
(),
transforms
.
RandomCrop
(
size
=
32
,
padding
=
4
)]
transform_pipeline
=
transforms
.
Compose
(
transform_pipeline
)
dataset
=
datasets
.
CIFAR10
(
root
=
Path
(
os
.
environ
[
'DATA'
]),
train
=
True
,
download
=
True
,
transform
=
transform_pipeline
)
# build dataloader
dataloader
=
D
ata
L
oader
(
dataset
=
dataset
,
**
dataloader_cfg
)
dataloader
=
get_d
ata
l
oader
(
dataset
,
batch_size
=
8
,
add_sampler
=
False
)
data_iter
=
iter
(
dataloader
)
img
,
label
=
data_iter
.
next
()
...
...
@@ -76,7 +67,6 @@ def run_data_sampler(rank, world_size, port):
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
cpu
@
rerun_if_address_is_in_use
()
def
test_data_sampler
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment