Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f477a14f
Unverified
Commit
f477a14f
authored
Jan 31, 2023
by
YuliangLiu0306
Committed by
GitHub
Jan 31, 2023
Browse files
[hotfix] fix autoparallel demo (#2533)
parent
63199c66
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+9
-2
No files found.
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
View file @
f477a14f
...
...
@@ -3,8 +3,9 @@ from torchvision.models import resnet50
 from tqdm import tqdm

 import colossalai
-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingLR
...
...
@@ -22,9 +23,14 @@ def main():
     # trace the model with meta data
     model = resnet50(num_classes=10).cuda()
     input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
+    device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
+    model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)
-    model = autoparallelize(model, input_sample)
+    if gpc.get_global_rank() == 0:
+        for node_strategy in solution:
+            print(node_strategy)

     # build criterion
     criterion = torch.nn.CrossEntropyLoss()
...
...
@@ -52,6 +58,7 @@ def main():
             output = model(img)
             train_loss = criterion(output, label)
             train_loss.backward(train_loss)
+            torch.cuda.synchronize()
             optimizer.step()
             lr_scheduler.step()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment