Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
edf4cd46
Unverified
Commit
edf4cd46
authored
Dec 01, 2022
by
YuliangLiu0306
Committed by
GitHub
Dec 01, 2022
Browse files
[examples] update autoparallel demo (#2061)
parent
1c1fe443
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
14 deletions
+14
-14
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+12
-14
examples/tutorial/auto_parallel/config.py
examples/tutorial/auto_parallel/config.py
+2
-0
No files found.
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
View file @
edf4cd46
...
@@ -15,7 +15,7 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
...
@@ -15,7 +15,7 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.solver.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver.options
import
DataloaderOption
,
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
from
colossalai.auto_parallel.tensor_shard.solver.strategies_constructor
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.solver.strategies_constructor
import
StrategiesConstructor
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
...
@@ -26,8 +26,6 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
...
@@ -26,8 +26,6 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
from
colossalai.utils
import
get_dataloader
from
colossalai.utils
import
get_dataloader
DATA_ROOT
=
Path
(
os
.
environ
.
get
(
'DATA'
,
'../data'
)).
absolute
()
DATA_ROOT
=
Path
(
os
.
environ
.
get
(
'DATA'
,
'../data'
)).
absolute
()
BATCH_SIZE
=
1024
NUM_EPOCHS
=
10
def
parse_args
():
def
parse_args
():
...
@@ -37,14 +35,14 @@ def parse_args():
...
@@ -37,14 +35,14 @@ def parse_args():
def
synthesize_data
():
def
synthesize_data
():
img
=
torch
.
rand
(
BATCH_SIZE
,
3
,
32
,
32
)
img
=
torch
.
rand
(
gpc
.
config
.
BATCH_SIZE
,
3
,
32
,
32
)
label
=
torch
.
randint
(
low
=
0
,
high
=
10
,
size
=
(
BATCH_SIZE
,))
label
=
torch
.
randint
(
low
=
0
,
high
=
10
,
size
=
(
gpc
.
config
.
BATCH_SIZE
,))
return
img
,
label
return
img
,
label
def
main
():
def
main
():
args
=
parse_args
()
args
=
parse_args
()
colossalai
.
launch_from_torch
(
config
=
{}
)
colossalai
.
launch_from_torch
(
config
=
'./config.py'
)
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
...
@@ -70,16 +68,16 @@ def main():
...
@@ -70,16 +68,16 @@ def main():
train_dataloader
=
get_dataloader
(
train_dataloader
=
get_dataloader
(
dataset
=
train_dataset
,
dataset
=
train_dataset
,
add_sampler
=
False
,
add_sampler
=
True
,
shuffle
=
True
,
shuffle
=
True
,
batch_size
=
BATCH_SIZE
,
batch_size
=
gpc
.
config
.
BATCH_SIZE
,
pin_memory
=
True
,
pin_memory
=
True
,
)
)
test_dataloader
=
get_dataloader
(
test_dataloader
=
get_dataloader
(
dataset
=
test_dataset
,
dataset
=
test_dataset
,
add_sampler
=
False
,
add_sampler
=
True
,
batch_size
=
BATCH_SIZE
,
batch_size
=
gpc
.
config
.
BATCH_SIZE
,
pin_memory
=
True
,
pin_memory
=
True
,
)
)
else
:
else
:
...
@@ -93,13 +91,13 @@ def main():
...
@@ -93,13 +91,13 @@ def main():
# trace the model with meta data
# trace the model with meta data
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
model
=
resnet50
(
num_classes
=
10
).
cuda
()
model
=
resnet50
(
num_classes
=
10
).
cuda
()
input_sample
=
{
'x'
:
torch
.
rand
([
1024
,
3
,
32
,
32
]).
to
(
'meta'
)}
input_sample
=
{
'x'
:
torch
.
rand
([
gpc
.
config
.
BATCH_SIZE
*
torch
.
distributed
.
get_world_size
()
,
3
,
32
,
32
]).
to
(
'meta'
)}
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
input_sample
)
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
input_sample
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
.
recompile
()
# prepare info for solver
# prepare info for solver
solver_options
=
SolverOptions
(
fast
=
True
)
solver_options
=
SolverOptions
(
dataloader_option
=
DataloaderOption
.
DISTRIBUTED
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
...
@@ -126,9 +124,9 @@ def main():
...
@@ -126,9 +124,9 @@ def main():
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
# lr_scheduler
# lr_scheduler
lr_scheduler
=
CosineAnnealingLR
(
optimizer
,
total_steps
=
NUM_EPOCHS
)
lr_scheduler
=
CosineAnnealingLR
(
optimizer
,
total_steps
=
gpc
.
config
.
NUM_EPOCHS
)
for
epoch
in
range
(
NUM_EPOCHS
):
for
epoch
in
range
(
gpc
.
config
.
NUM_EPOCHS
):
gm
.
train
()
gm
.
train
()
if
args
.
synthetic
:
if
args
.
synthetic
:
...
...
examples/tutorial/auto_parallel/config.py
0 → 100644
View file @
edf4cd46
BATCH_SIZE
=
128
NUM_EPOCHS
=
10
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment