Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5da87ce3
Unverified
Commit
5da87ce3
authored
Jul 06, 2022
by
Frank Lee
Committed by
GitHub
Jul 06, 2022
Browse files
[fx] added testing for all albert variants (#1211)
parent
2d13a45a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
193 additions
and
4 deletions
+193
-4
colossalai/fx/proxy.py
colossalai/fx/proxy.py
+2
-2
colossalai/fx/tracer/meta_patch/patched_function.py
colossalai/fx/tracer/meta_patch/patched_function.py
+54
-1
colossalai/fx/tracer/meta_patch/patched_module.py
colossalai/fx/tracer/meta_patch/patched_module.py
+6
-0
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+65
-0
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+31
-0
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+32
-0
tests/test_fx/test_tracer/test_hf_model/utils.py
tests/test_fx/test_tracer/test_hf_model/utils.py
+3
-1
No files found.
colossalai/fx/proxy.py
View file @
5da87ce3
...
...
@@ -39,7 +39,7 @@ class ColoProxy(Proxy):
self
.
_meta_data
)
and
self
.
_meta_data
.
is_meta
,
f
'Meta data is not a meta tensor for
{
self
.
node
.
name
}
'
def
_assert_has_meta_data
(
self
):
assert
self
.
_meta_data
,
f
'Meta data is not set for
{
self
.
node
.
name
}
'
assert
self
.
_meta_data
is
not
None
,
f
'Meta data is not set for
{
self
.
node
.
name
}
'
@
property
def
device
(
self
):
...
...
@@ -63,7 +63,7 @@ class ColoProxy(Proxy):
def
size
(
self
,
dim
:
int
=
None
):
self
.
_assert_meta_data_is_tensor
()
if
dim
:
if
dim
is
not
None
:
return
self
.
meta_data
.
size
(
dim
=
dim
)
else
:
# size(dim=None) will trigger runtime error for meta tensor
...
...
colossalai/fx/tracer/meta_patch/patched_function.py
View file @
5da87ce3
from
curses
import
meta
import
operator
import
torch
from
.registry
import
meta_patched_function
...
...
@@ -89,3 +90,55 @@ def torch_where(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return
condition
.
to
(
device
=
"meta"
)
+
x
.
to
(
device
=
"meta"
)
+
y
.
to
(
device
=
"meta"
)
@
meta_patched_function
.
register
(
torch
.
abs
)
def torch_abs(input, *, out=None):
    """Meta-device stand-in for ``torch.abs``.

    No values are computed: an uninitialised tensor with the same shape as
    ``input`` is allocated on the ``meta`` device. No dtype is forwarded to
    ``torch.empty``.

    Args:
        input: tensor whose ``shape`` determines the result shape.
        out: must be left as ``None``.

    Raises:
        AssertionError: if ``out`` is given.
    """
    assert out is None, 'out is not supported yet'
    result_shape = input.shape
    return torch.empty(result_shape, device='meta')
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
relu
)
def torch_nn_func_relu(input, inplace=False):
    """Meta-device stand-in for ``torch.nn.functional.relu``.

    Allocates an uninitialised ``meta`` tensor shaped like ``input`` instead
    of applying the activation.

    Args:
        input: tensor whose ``shape`` determines the result shape.
        inplace: must be falsy.

    Raises:
        AssertionError: if ``inplace`` is truthy.
    """
    assert not inplace, 'inplace is not supported yet'
    result_shape = input.shape
    return torch.empty(result_shape, device='meta')
@
meta_patched_function
.
register
(
torch
.
Tensor
.
repeat
)
def torch_tensor_repeat(self, *sizes):
    """Meta-device stand-in for ``torch.Tensor.repeat``.

    Computes the shape that repeating ``self`` would produce and returns an
    uninitialised tensor of that shape on the ``meta`` device.

    Args:
        self: tensor whose ``shape`` is being repeated.
        *sizes: repeat count per dimension, given either as separate
            arguments (``x.repeat(2, 3)``) or as one tuple/list/``torch.Size``
            (``x.repeat((2, 3))``).

    Returns:
        A ``meta`` tensor with the repeated shape.

    Raises:
        RuntimeError: if fewer repeat counts than tensor dimensions are given.
    """
    # Accept the single-sequence calling convention as well as varargs.
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])

    ndim = len(self.shape)
    if len(sizes) < ndim:
        raise RuntimeError('Number of dimensions of repeat dims can not be smaller than '
                           'number of dimensions of tensor')

    # When more repeat counts than dimensions are supplied, the tensor is
    # treated as if it had extra leading dimensions of size 1.
    padded_shape = [1] * (len(sizes) - ndim) + list(self.shape)
    shape = [dim * rep for dim, rep in zip(padded_shape, sizes)]
    return torch.empty(shape, device="meta")
@
meta_patched_function
.
register
(
torch
.
index_select
)
def torch_index_select(input, dim, index, *, out=None):
    """Meta-device stand-in for ``torch.index_select``.

    The result has the shape of ``input`` except along ``dim``, where its
    length becomes ``len(index)``. The ``out`` argument is accepted but never
    used by this function.

    Args:
        input: tensor whose ``shape`` is the starting point.
        dim: dimension along which entries are selected.
        index: sized collection of indices; only its length is read.
        out: ignored.

    Returns:
        An uninitialised tensor of the resulting shape on the ``meta`` device.
    """
    out_dims = [*input.shape]
    out_dims[dim] = len(index)
    return torch.empty(*out_dims, device="meta")
@
meta_patched_function
.
register
(
torch
.
Tensor
.
index_select
)
def torch_tensor_index_select(self, dim, index):
    """Meta-device stand-in for ``torch.Tensor.index_select``.

    Delegates to ``torch_index_select`` with ``self`` as the input tensor and
    returns whatever that function returns.
    """
    selected = torch_index_select(self, dim, index)
    return selected
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
embedding
)
def torch_nn_functional_embedding(input,
                                  weight,
                                  padding_idx=None,
                                  max_norm=None,
                                  norm_type=2.0,
                                  scale_grad_by_freq=False,
                                  sparse=False):
    """Meta-device stand-in for ``torch.nn.functional.embedding``.

    The result shape is the shape of ``input`` with the last dimension of
    ``weight`` appended. All remaining parameters are accepted for signature
    compatibility and are not read.

    Returns:
        An uninitialised tensor of that shape on the ``meta`` device.
    """
    embedding_dim = weight.shape[-1]
    output_shape = (*input.shape, embedding_dim)
    return torch.empty(*output_shape, device="meta")
@
meta_patched_function
.
register
(
torch
.
bmm
)
def torch_bmm(input, mat2, *, out=None):
    """Meta-device stand-in for ``torch.bmm``.

    Unpacks ``input.shape`` as ``(batch_size, n, m)`` and ``mat2.shape`` as
    ``(_, _, p)`` and returns an uninitialised ``(batch_size, n, p)`` tensor
    on the ``meta`` device. Both arguments must therefore be 3-dimensional;
    the batch and inner dimensions of ``mat2`` are not checked against
    ``input``.

    Args:
        input: first batch of matrices.
        mat2: second batch of matrices.
        out: must be left as ``None``.

    Raises:
        ValueError: if ``out`` is given.
    """
    if out is not None:
        # The message previously named "abs" (copy-paste slip); this is bmm.
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")
colossalai/fx/tracer/meta_patch/patched_module.py
View file @
5da87ce3
...
...
@@ -116,3 +116,9 @@ def torch_nn_maxpool3d(self, input):
w_out
,
)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU
)
def torch_nn_func_relu(self, input):
    """Meta patch for the ``torch.nn.ReLU`` module.

    Returns a clone of ``input`` rather than applying the activation.

    Args:
        self: the module being patched; only its ``inplace`` flag is read.
        input: tensor to clone.

    Raises:
        AssertionError: if ``self.inplace`` is truthy.
    """
    in_place = self.inplace
    assert not in_place, 'inplace is not supported yet'
    return input.clone()
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
0 → 100644
View file @
5da87ce3
import
transformers
import
torch
from
utils
import
trace_model_and_compare_output
# Dimensions of the all-zero dummy inputs built by the tests in this file:
# tensors are created with shape (BATCH_SIZE, SEQ_LENGHT).
# NOTE(review): "SEQ_LENGHT" is a misspelling of "SEQ_LENGTH"; the name is kept
# because the test code in this file references it as spelled.
BATCH_SIZE = 2
SEQ_LENGHT = 16
def test_single_sentence_albert():
    """Run the trace-and-compare check on the single-sequence ALBERT variants.

    Each model class is built from one small shared ``AlbertConfig`` and
    handed to ``trace_model_and_compare_output`` together with a generator of
    all-zero ``int64`` inputs of shape ``(BATCH_SIZE, SEQ_LENGHT)``.
    """
    model_classes = [
        transformers.AlbertModel,
        transformers.AlbertForPreTraining,
        transformers.AlbertForMaskedLM,
        transformers.AlbertForSequenceClassification,
        transformers.AlbertForTokenClassification,
    ]

    config = transformers.AlbertConfig(embedding_size=128,
                                       hidden_size=128,
                                       num_hidden_layers=2,
                                       num_attention_heads=4,
                                       intermediate_size=256)

    def data_gen():
        # One fresh zero tensor per input name, in the same key order.
        input_names = ('input_ids', 'token_type_ids', 'attention_mask')
        shape = (BATCH_SIZE, SEQ_LENGHT)
        return {name: torch.zeros(shape, dtype=torch.int64) for name in input_names}

    for model_cls in model_classes:
        trace_model_and_compare_output(model_cls(config=config), data_gen)
def test_multi_sentence_albert():
    """Run the trace-and-compare check on the sentence-pair ALBERT variants.

    Covers ``AlbertForQuestionAnswering`` and ``AlbertForMultipleChoice``.
    Inputs are produced by the ``bert-base-uncased`` ``BertTokenizer`` loaded
    via ``from_pretrained``.
    """
    config = transformers.AlbertConfig(hidden_size=128,
                                       num_hidden_layers=2,
                                       num_attention_heads=4,
                                       intermediate_size=256)
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

    # --- question answering -------------------------------------------------
    def data_gen_for_qa():
        question = "Who was Jim Henson?"
        text = "Jim Henson was a nice puppet"
        return tokenizer(question, text, return_tensors="pt")

    qa_model = transformers.AlbertForQuestionAnswering(config)
    trace_model_and_compare_output(qa_model, data_gen_for_qa)

    # --- multiple choice ----------------------------------------------------
    def data_gen_for_mcq():
        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        choice0 = "It is eaten with a fork and a knife."
        choice1 = "It is eaten while held in the hand."
        tokenized = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
        # Add a leading dimension of size 1 to every tensor in the encoding.
        batched = {}
        for key, value in tokenized.items():
            batched[key] = value.unsqueeze(0)
        return batched

    mcq_model = transformers.AlbertForMultipleChoice(config)
    trace_model_and_compare_output(mcq_model, data_gen_for_mcq)
# Run both ALBERT tests directly when this file is executed as a script.
if __name__ == '__main__':
    test_single_sentence_albert()
    test_multi_sentence_albert()
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
0 → 100644
View file @
5da87ce3
import
pytest
import
transformers
import
torch
from
utils
import
trace_model_and_compare_output
# Dimensions of the all-zero dummy inputs built by the test in this file:
# tensors are created with shape (BATCH_SIZE, SEQ_LENGHT).
# NOTE(review): "SEQ_LENGHT" is a misspelling of "SEQ_LENGTH"; the name is kept
# because the test code in this file references it as spelled.
BATCH_SIZE = 1
SEQ_LENGHT = 16
@
pytest
.
mark
.
skip
(
'value is not aligned yet'
)
def test_opt():
    """Run the trace-and-compare check on the OPT model variants.

    ``OPTModel`` and ``OPTForCausalLM`` are built from one small shared
    ``OPTConfig`` and passed to ``trace_model_and_compare_output`` with
    all-zero ``int64`` inputs of shape ``(BATCH_SIZE, SEQ_LENGHT)``.
    """
    model_classes = [
        transformers.OPTModel,
        transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

    def data_gen():
        # One fresh zero tensor per input name, in the same key order.
        shape = (BATCH_SIZE, SEQ_LENGHT)
        return {name: torch.zeros(shape, dtype=torch.int64) for name in ('input_ids', 'attention_mask')}

    for model_cls in model_classes:
        trace_model_and_compare_output(model_cls(config=config), data_gen)
# Run the OPT test directly when this file is executed as a script.
if __name__ == '__main__':
    test_opt()
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
0 → 100644
View file @
5da87ce3
import
pytest
import
transformers
import
torch
from
utils
import
trace_model_and_compare_output
# Dimensions of the all-zero dummy inputs built by the test in this file:
# tensors are created with shape (BATCH_SIZE, SEQ_LENGHT).
# NOTE(review): "SEQ_LENGHT" is a misspelling of "SEQ_LENGTH"; the name is kept
# because the test code in this file references it as spelled.
BATCH_SIZE = 1
SEQ_LENGHT = 16
@
pytest
.
mark
.
skip
(
'value is not aligned yet'
)
def test_t5():
    """Run the trace-and-compare check on the T5 model variants.

    ``T5Model``, ``T5ForConditionalGeneration`` and ``T5EncoderModel`` are
    built from one small shared ``T5Config`` and passed to
    ``trace_model_and_compare_output`` with all-zero ``int64`` inputs of shape
    ``(BATCH_SIZE, SEQ_LENGHT)``.
    """
    model_classes = [
        transformers.T5Model,
        transformers.T5ForConditionalGeneration,
        transformers.T5EncoderModel,
    ]

    config = transformers.T5Config(d_model=128, num_layers=2)

    def data_gen():
        # One fresh zero tensor per input name, in the same key order.
        shape = (BATCH_SIZE, SEQ_LENGHT)
        return {name: torch.zeros(shape, dtype=torch.int64) for name in ('input_ids', 'decoder_input_ids')}

    for model_cls in model_classes:
        trace_model_and_compare_output(model_cls(config=config), data_gen)
# Run the T5 test directly when this file is executed as a script.
if __name__ == '__main__':
    test_t5()
tests/test_fx/test_tracer/test_hf_model/utils.py
View file @
5da87ce3
...
...
@@ -30,4 +30,6 @@ def trace_model_and_compare_output(model, data_gen):
for
k
in
non_fx_out
.
keys
():
if
torch
.
is_tensor
(
fx_out
[
k
]):
assert
torch
.
equal
(
fx_out
[
k
],
non_fx_out
[
k
]),
f
'
{
model
.
__class__
.
__name__
}
has incorrect output
{
k
}
'
assert
torch
.
equal
(
fx_out
[
k
],
non_fx_out
[
k
]
),
f
'
{
model
.
__class__
.
__name__
}
has incorrect output
{
k
}
, expect
{
non_fx_out
[
k
]
}
, but got
{
fx_out
[
k
]
}
'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment