Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4f21c9e8
Unverified
Commit
4f21c9e8
authored
Dec 05, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 05, 2022
Browse files
[Gemini] polish runtime tracer tests (#2077)
parent
677e1e20
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
13 deletions
+2
-13
tests/test_gemini/test_runtime_mem_tracer.py
tests/test_gemini/test_runtime_mem_tracer.py
+2
-13
No files found.
tests/test_gemini/test_runtime_mem_tracer.py
View file @
4f21c9e8
...
@@ -10,17 +10,6 @@ from tests.components_to_test import run_fwd_bwd
...
@@ -10,17 +10,6 @@ from tests.components_to_test import run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
enable_autocast
=
False
,
dtype
=
torch
.
half
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
if
criterion
:
y
=
model
(
data
)
loss
=
criterion
(
y
,
label
)
else
:
loss
=
model
(
data
,
label
)
loss
=
loss
.
to
(
dtype
)
model
.
backward
(
loss
)
def
test_runtime_mem_tracer
():
def
test_runtime_mem_tracer
():
test_models
=
[
'gpt2'
,
'bert'
,
'simple_net'
,
'repeated_computed_layers'
,
'nested_model'
,
'albert'
]
test_models
=
[
'gpt2'
,
'bert'
,
'simple_net'
,
'repeated_computed_layers'
,
'nested_model'
,
'albert'
]
...
@@ -28,7 +17,7 @@ def test_runtime_mem_tracer():
...
@@ -28,7 +17,7 @@ def test_runtime_mem_tracer():
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
torch
.
device
(
'cpu'
)
)
:
with
ColoInitContext
(
device
=
'cpu'
):
model
=
model_builder
(
checkpoint
=
False
)
model
=
model_builder
(
checkpoint
=
False
)
model_bk
=
deepcopy
(
model
)
model_bk
=
deepcopy
(
model
)
...
@@ -40,7 +29,7 @@ def test_runtime_mem_tracer():
...
@@ -40,7 +29,7 @@ def test_runtime_mem_tracer():
data
=
data
.
cuda
()
data
=
data
.
cuda
()
label
=
label
.
cuda
()
label
=
label
.
cuda
()
run_fwd_bwd
(
runtime_mem_tracer
,
data
,
label
,
criterion
,
False
)
run_fwd_bwd
(
runtime_mem_tracer
,
data
,
label
,
criterion
,
optimizer
=
runtime_mem_tracer
)
for
p1
,
p2
in
zip
(
model_bk
.
parameters
(),
model
.
parameters
()):
for
p1
,
p2
in
zip
(
model_bk
.
parameters
(),
model
.
parameters
()):
torch
.
allclose
(
p1
.
to
(
torch
.
half
),
p2
)
torch
.
allclose
(
p1
.
to
(
torch
.
half
),
p2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment