Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4d94cd51
Commit
4d94cd51
authored
Mar 09, 2022
by
jiaruifang
Committed by
Frank Lee
Mar 11, 2022
Browse files
adapting bert unitest interface
parent
7977422a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
6 deletions
+18
-6
tests/components_to_test/bert.py
tests/components_to_test/bert.py
+13
-1
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+5
-5
No files found.
tests/components_to_test/bert.py
View file @
4d94cd51
...
...
@@ -48,9 +48,21 @@ def get_training_components():
num_hidden_layers
=
num_layer
,
)
print
(
'building BertForSequenceClassification model'
)
model
=
BertForSequenceClassification
(
config
)
# adapting huggingface BertForSequenceClassification for single unitest calling interface
class
ModelAaptor
(
BertForSequenceClassification
):
def
forward
(
self
,
input_ids
,
labels
):
"""
inputs: data, label
outputs: loss
"""
return
super
().
forward
(
input_ids
=
input_ids
,
labels
=
labels
)[
0
]
model
=
ModelAaptor
(
config
)
if
checkpoint
and
version
.
parse
(
transformers
.
__version__
)
>=
version
.
parse
(
"4.11.0"
):
model
.
gradient_checkpointing_enable
()
return
model
trainloader
=
get_bert_data_loader
(
batch_size
=
2
,
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
4d94cd51
...
...
@@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
loss
.
backward
()
def
run_bert_fwd_bwd
(
model
,
data
,
label
,
enable_autocast
=
False
):
# with no criterion
def
run_fwd_bwd_no_criterion
(
model
,
data
,
label
,
enable_autocast
=
False
):
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
output
=
model
(
input_ids
=
data
,
labels
=
label
)
loss
=
output
[
0
]
loss
=
model
(
data
,
label
)
if
isinstance
(
model
,
ShardedModelV2
):
model
.
backward
(
loss
)
else
:
...
...
@@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
if
model_name
==
'bert'
:
data
,
label
=
data
.
cuda
(),
label
.
cuda
()
run_
bert_
fwd_bwd
(
model
,
data
,
label
,
False
)
run_
bert_
fwd_bwd
(
zero_model
,
data
,
label
,
False
)
run_fwd_bwd
_no_criterion
(
model
,
data
,
label
,
False
)
run_fwd_bwd
_no_criterion
(
zero_model
,
data
,
label
,
False
)
else
:
data
,
label
=
data
.
half
().
cuda
(),
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment