Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f4f42d4c
Unverified
Commit
f4f42d4c
authored
Apr 13, 2022
by
Frank Lee
Committed by
GitHub
Apr 13, 2022
Browse files
[bug] fixed DDP compatibility with torch 1.8 (#739)
parent
a4e91bc8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
3 deletions
+3
-3
tests/test_zero/test_shard_model_v2.py
tests/test_zero/test_shard_model_v2.py
+1
-1
tests/test_zero/test_sharded_optim_v2.py
tests/test_zero/test_sharded_optim_v2.py
+1
-1
tests/test_zero/test_zero_engine.py
tests/test_zero/test_zero_engine.py
+1
-1
No files found.
tests/test_zero/test_shard_model_v2.py
View file @
f4f42d4c
...
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
...
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
col_model_deepcopy
(
zero_model
,
model
)
col_model_deepcopy
(
zero_model
,
model
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
model
=
DDP
(
model
)
model
=
DDP
(
model
,
device_ids
=
[
torch
.
cuda
.
current_device
()]
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
5
:
if
i
>
5
:
...
...
tests/test_zero/test_sharded_optim_v2.py
View file @
f4f42d4c
...
@@ -86,7 +86,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
...
@@ -86,7 +86,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
)
apex_model
,
apex_optimizer
=
convert_to_apex_amp
(
model
,
optim
,
amp_config
)
apex_model
,
apex_optimizer
=
convert_to_apex_amp
(
model
,
optim
,
amp_config
)
if
dist
.
get_world_size
()
>
1
:
if
dist
.
get_world_size
()
>
1
:
apex_model
=
DDP
(
apex_model
)
apex_model
=
DDP
(
apex_model
,
device_ids
=
[
torch
.
cuda
.
current_device
()]
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
5
:
if
i
>
5
:
...
...
tests/test_zero/test_zero_engine.py
View file @
f4f42d4c
...
@@ -50,7 +50,7 @@ def run_dist(rank, world_size, port, parallel_config):
...
@@ -50,7 +50,7 @@ def run_dist(rank, world_size, port, parallel_config):
torch_optimizer
=
optimizer_class
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_optimizer
=
optimizer_class
(
torch_model
.
parameters
(),
lr
=
1e-3
)
if
dist
.
get_world_size
()
>
1
:
if
dist
.
get_world_size
()
>
1
:
torch_model
=
DDP
(
torch_model
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
torch
.
cuda
.
current_device
()]
)
i
=
0
i
=
0
for
data
,
label
in
train_dataloader
:
for
data
,
label
in
train_dataloader
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment