Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
426a279c
Unverified
Commit
426a279c
authored
Jul 06, 2022
by
Frank Lee
Committed by
GitHub
Jul 06, 2022
Browse files
[fx] added testing for all bert variants (#1207)
* [fx] added testing for all bert variants * polish code
parent
b5f25eb3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
20 deletions
+88
-20
colossalai/fx/proxy.py
colossalai/fx/proxy.py
+16
-2
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+72
-18
No files found.
colossalai/fx/proxy.py
View file @
426a279c
import
operator
import
torch
from
torch.fx.proxy
import
Proxy
,
Attribute
from
typing
import
List
,
Union
from
torch.utils._pytree
import
PyTree
__all__
=
[
'ColoProxy'
]
...
...
@@ -26,8 +28,12 @@ class ColoProxy(Proxy):
return
self
.
_meta_tensor
@
meta_tensor
.
setter
def
meta_tensor
(
self
,
tensor
:
torch
.
Tensor
):
assert
tensor
is
None
or
tensor
.
is_meta
,
'Expected to receive a meta tensor, but got a non-meta tensor'
def
meta_tensor
(
self
,
tensor
:
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
]):
def
_is_meta
(
item
):
assert
torch
.
is_tensor
(
item
)
and
item
.
is_meta
torch
.
fx
.
node
.
map_aggregate
(
tensor
,
_is_meta
)
self
.
_meta_tensor
=
tensor
@
property
...
...
@@ -83,6 +89,14 @@ class ColoProxy(Proxy):
def
__setitem__
(
self
,
indices
,
values
):
return
self
.
tracer
.
create_proxy
(
"call_function"
,
operator
.
setitem
,
(
self
,
indices
,
values
),
{})
def
__contains__
(
self
,
key
):
if
self
.
node
.
op
==
"placeholder"
:
# this is used to handle like
# if x in kwargs
# we don't handle this case for now
return
False
return
super
().
__contains__
(
key
)
class
ColoAttribute
(
ColoProxy
):
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
View file @
426a279c
...
...
@@ -7,36 +7,90 @@ BATCH_SIZE = 2
SEQ_LENGHT
=
16
def
t
est
_bert
(
):
def
t
race
_bert
_and_compare_output
(
model
,
data_gen
):
tracer
=
ColoTracer
()
config
=
transformers
.
BertConfig
()
model
=
transformers
.
BertModel
(
config
=
config
)
input_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
,
device
=
'meta'
)
token_type_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
,
device
=
'meta'
)
attention_mask
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
,
device
=
'meta'
)
meta_args
=
dict
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
)
# make sure that the model is traceable
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
try
:
kwargs
=
data_gen
()
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
kwargs
.
items
()}
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Failed to trace
{
model
.
__class__
.
__name__
}
, error:
{
e
}
"
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# check output
input_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
token_type_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
attention_mask
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
inputs
=
data_gen
()
# must turn on eval mode to ensure the output is consistent
gm
.
eval
()
model
.
eval
()
# run forward
fx_out
=
gm
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
)
non_fx_out
=
model
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
)
assert
fx_out
[
'last_hidden_state'
].
shape
==
non_fx_out
[
'last_hidden_state'
].
shape
assert
torch
.
equal
(
fx_out
[
'last_hidden_state'
],
non_fx_out
[
'last_hidden_state'
])
non_fx_out
=
model
(
**
inputs
)
fx_out
=
gm
(
**
inputs
)
for
k
in
non_fx_out
.
keys
():
assert
torch
.
equal
(
fx_out
[
k
],
non_fx_out
[
k
]),
f
'
{
model
.
__class__
.
__name__
}
has incorrect output
{
k
}
'
def
test_single_sentence_bert
():
MODEL_LIST
=
[
transformers
.
BertModel
,
transformers
.
BertForPreTraining
,
transformers
.
BertLMHeadModel
,
transformers
.
BertForMaskedLM
,
transformers
.
BertForSequenceClassification
,
transformers
.
BertForTokenClassification
,
]
config
=
transformers
.
BertConfig
(
hidden_size
=
128
,
num_hidden_layers
=
2
,
num_attention_heads
=
4
,
intermediate_size
=
256
)
def
data_gen
():
input_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
token_type_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
attention_mask
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
meta_args
=
dict
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
)
return
meta_args
for
model_cls
in
MODEL_LIST
:
model
=
model_cls
(
config
=
config
)
trace_bert_and_compare_output
(
model
,
data_gen
)
def
test_multi_sentence_bert
():
config
=
transformers
.
BertConfig
(
hidden_size
=
128
,
num_hidden_layers
=
2
,
num_attention_heads
=
4
,
intermediate_size
=
256
)
tokenizer
=
transformers
.
BertTokenizer
.
from_pretrained
(
"bert-base-uncased"
)
def
data_gen_for_next_sentence
():
prompt
=
"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence
=
"The sky is blue due to the shorter wavelength of blue light."
encoding
=
tokenizer
(
prompt
,
next_sentence
,
return_tensors
=
"pt"
)
return
encoding
model
=
transformers
.
BertForNextSentencePrediction
(
config
)
trace_bert_and_compare_output
(
model
,
data_gen_for_next_sentence
)
def
data_gen_for_qa
():
question
,
text
=
"Who was Jim Henson?"
,
"Jim Henson was a nice puppet"
inputs
=
tokenizer
(
question
,
text
,
return_tensors
=
"pt"
)
return
inputs
model
=
transformers
.
BertForQuestionAnswering
(
config
)
trace_bert_and_compare_output
(
model
,
data_gen_for_qa
)
def
data_gen_for_mcq
():
prompt
=
"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0
=
"It is eaten with a fork and a knife."
choice1
=
"It is eaten while held in the hand."
encoding
=
tokenizer
([
prompt
,
prompt
],
[
choice0
,
choice1
],
return_tensors
=
"pt"
,
padding
=
True
)
encoding
=
{
k
:
v
.
unsqueeze
(
0
)
for
k
,
v
in
encoding
.
items
()}
return
encoding
model
=
transformers
.
BertForMultipleChoice
(
config
)
trace_bert_and_compare_output
(
model
,
data_gen_for_mcq
)
if
__name__
==
'__main__'
:
test_bert
()
test_single_sentence_bert
()
test_multi_sentence_bert
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment