Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d271f259
Commit
d271f259
authored
Mar 09, 2022
by
jiaruifang
Committed by
Frank Lee
Mar 11, 2022
Browse files
polish engine unitest
parent
354c0f90
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
tests/test_engine/test_engine.py
tests/test_engine/test_engine.py
+10
-8
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+1
-0
No files found.
tests/test_engine/test_engine.py
View file @
d271f259
...
@@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
...
@@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
def
run_train
():
def
run_train
():
assert
non_distributed_component_funcs
.
get_callable
(
'bert'
)
for
get_components_func
in
non_distributed_component_funcs
:
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
train_dataloader
,
_
,
optimizer_builder
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
optimizer_builder
,
criterion
=
get_components_func
()
...
@@ -27,12 +26,15 @@ def run_train():
...
@@ -27,12 +26,15 @@ def run_train():
try
:
try
:
engine
.
train
()
engine
.
train
()
for
img
,
label
in
train_dataloader
:
for
data
,
label
in
train_dataloader
:
engine
.
zero_grad
()
engine
.
zero_grad
()
img
=
img
.
cuda
()
data
=
data
.
cuda
()
label
=
label
.
cuda
()
label
=
label
.
cuda
()
output
=
engine
(
img
)
if
criterion
:
loss
=
engine
.
criterion
(
output
,
label
)
output
=
engine
(
data
)
loss
=
engine
.
criterion
(
output
,
label
)
else
:
loss
=
engine
(
data
,
label
)
engine
.
backward
(
loss
)
engine
.
backward
(
loss
)
engine
.
step
()
engine
.
step
()
break
break
...
@@ -72,9 +74,9 @@ def run_engine(rank, world_size, port):
...
@@ -72,9 +74,9 @@ def run_engine(rank, world_size, port):
# init dist env
# init dist env
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_with_no_amp
()
run_with_no_amp
()
#
run_with_torch_amp()
run_with_torch_amp
()
#
run_with_apex_amp()
run_with_apex_amp
()
#
run_with_naive_amp()
run_with_naive_amp
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
d271f259
...
@@ -76,6 +76,7 @@ def run_dist(rank, world_size, port):
...
@@ -76,6 +76,7 @@ def run_dist(rank, world_size, port):
check_grads
(
model
,
zero_model
,
loose
=
True
)
check_grads
(
model
,
zero_model
,
loose
=
True
)
@
pytest
.
mark
.
skip
(
reason
=
"Under development"
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
,
4
])
def
test_shard_model_v2
(
world_size
):
def
test_shard_model_v2
(
world_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment