Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
616ed91e
Unverified
Commit
616ed91e
authored
Dec 05, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 05, 2022
Browse files
[test] bert test in non-distributed way (#2074)
parent
223332ff
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
tests/components_to_test/bert.py
tests/components_to_test/bert.py
+3
-2
tests/test_gemini/test_runtime_mem_tracer.py
tests/test_gemini/test_runtime_mem_tracer.py
+5
-4
No files found.
tests/components_to_test/bert.py
View file @
616ed91e
...
...
@@ -68,16 +68,17 @@ def get_training_components():
return
model
is_distrbuted
=
torch
.
distributed
.
is_initialized
()
trainloader
=
get_bert_data_loader
(
n_class
=
vocab_size
,
batch_size
=
2
,
total_samples
=
10000
,
sequence_length
=
sequence_length
,
is_distrbuted
=
True
)
is_distrbuted
=
is_distrbuted
)
testloader
=
get_bert_data_loader
(
n_class
=
vocab_size
,
batch_size
=
2
,
total_samples
=
10000
,
sequence_length
=
sequence_length
,
is_distrbuted
=
True
)
is_distrbuted
=
is_distrbuted
)
criterion
=
None
return
bert_model_builder
,
trainloader
,
testloader
,
torch
.
optim
.
Adam
,
criterion
tests/test_gemini/test_runtime_mem_tracer.py
View file @
616ed91e
...
...
@@ -21,14 +21,15 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
model
.
backward
(
loss
)
def
run_param_wrapper_testing
():
test_models
=
[
'simple_net'
,
'repeated_computed_layers'
,
'nested_model'
]
def
test_runtime_mem_tracer
():
test_models
=
[
'gpt2'
,
'bert'
,
'simple_net'
,
'repeated_computed_layers'
,
'nested_model'
]
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
torch
.
device
(
'cpu'
)):
model
=
model_builder
(
checkpoint
=
False
)
model
=
model_builder
(
checkpoint
=
True
)
model_bk
=
deepcopy
(
model
)
runtime_mem_tracer
=
RuntimeMemTracer
(
model
)
...
...
@@ -52,4 +53,4 @@ def run_param_wrapper_testing():
if
__name__
==
'__main__'
:
run_param_wrapper_testing
()
test_runtime_mem_tracer
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment