Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1b0dd059
Unverified
Commit
1b0dd059
authored
Nov 12, 2022
by
Frank Lee
Committed by
GitHub
Nov 12, 2022
Browse files
[tutorial] added synthetic dataset for auto parallel demo (#1918)
parent
acd9abc5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
90 additions
and
37 deletions
+90
-37
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+90
-37
No files found.
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
View file @
1b0dd059
import
argparse
import
os
import
os
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -29,11 +30,25 @@ BATCH_SIZE = 1024
...
@@ -29,11 +30,25 @@ BATCH_SIZE = 1024
NUM_EPOCHS
=
10
NUM_EPOCHS
=
10
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'-s'
,
'--synthetic'
,
action
=
"store_true"
,
help
=
"use synthetic dataset instead of CIFAR10"
)
return
parser
.
parse_args
()
def
synthesize_data
():
img
=
torch
.
rand
(
BATCH_SIZE
,
3
,
32
,
32
)
label
=
torch
.
randint
(
low
=
0
,
high
=
10
,
size
=
(
BATCH_SIZE
,))
return
img
,
label
def
main
():
def
main
():
args
=
parse_args
()
colossalai
.
launch_from_torch
(
config
=
{})
colossalai
.
launch_from_torch
(
config
=
{})
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
if
not
args
.
synthetic
:
with
barrier_context
():
with
barrier_context
():
# build dataloaders
# build dataloaders
train_dataset
=
CIFAR10
(
root
=
DATA_ROOT
,
train_dataset
=
CIFAR10
(
root
=
DATA_ROOT
,
...
@@ -42,7 +57,8 @@ def main():
...
@@ -42,7 +57,8 @@ def main():
transforms
.
RandomCrop
(
size
=
32
,
padding
=
4
),
transforms
.
RandomCrop
(
size
=
32
,
padding
=
4
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
(
mean
=
[
0.4914
,
0.4822
,
0.4465
],
std
=
[
0.2023
,
0.1994
,
0.2010
]),
transforms
.
Normalize
(
mean
=
[
0.4914
,
0.4822
,
0.4465
],
std
=
[
0.2023
,
0.1994
,
0.2010
]),
]))
]))
test_dataset
=
CIFAR10
(
root
=
DATA_ROOT
,
test_dataset
=
CIFAR10
(
root
=
DATA_ROOT
,
...
@@ -66,6 +82,8 @@ def main():
...
@@ -66,6 +82,8 @@ def main():
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
pin_memory
=
True
,
pin_memory
=
True
,
)
)
else
:
train_dataloader
,
test_dataloader
=
None
,
None
# initialize device mesh
# initialize device mesh
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
...
@@ -112,11 +130,26 @@ def main():
...
@@ -112,11 +130,26 @@ def main():
for
epoch
in
range
(
NUM_EPOCHS
):
for
epoch
in
range
(
NUM_EPOCHS
):
gm
.
train
()
gm
.
train
()
if
gpc
.
get_global_rank
()
==
0
:
train_dl
=
tqdm
(
train_dataloader
)
if
args
.
synthetic
:
# if we use synthetic data
# we assume it only has 30 steps per epoch
num_steps
=
range
(
30
)
else
:
# we use the actual number of steps for training
num_steps
=
range
(
len
(
train_dataloader
))
data_iter
=
iter
(
train_dataloader
)
progress
=
tqdm
(
num_steps
)
for
_
in
progress
:
if
args
.
synthetic
:
# generate fake data
img
,
label
=
synthesize_data
()
else
:
else
:
train_dl
=
train_dataloader
# get the real data
for
img
,
label
in
train_dl
:
img
,
label
=
next
(
data_iter
)
img
=
img
.
cuda
()
img
=
img
.
cuda
()
label
=
label
.
cuda
()
label
=
label
.
cuda
()
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
...
@@ -126,10 +159,30 @@ def main():
...
@@ -126,10 +159,30 @@ def main():
optimizer
.
step
()
optimizer
.
step
()
lr_scheduler
.
step
()
lr_scheduler
.
step
()
# run evaluation
gm
.
eval
()
gm
.
eval
()
correct
=
0
correct
=
0
total
=
0
total
=
0
for
img
,
label
in
test_dataloader
:
if
args
.
synthetic
:
# if we use synthetic data
# we assume it only has 10 steps for evaluation
num_steps
=
range
(
30
)
else
:
# we use the actual number of steps for training
num_steps
=
range
(
len
(
test_dataloader
))
data_iter
=
iter
(
test_dataloader
)
progress
=
tqdm
(
num_steps
)
for
_
in
progress
:
if
args
.
synthetic
:
# generate fake data
img
,
label
=
synthesize_data
()
else
:
# get the real data
img
,
label
=
next
(
data_iter
)
img
=
img
.
cuda
()
img
=
img
.
cuda
()
label
=
label
.
cuda
()
label
=
label
.
cuda
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment