Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e691fc09
Commit
e691fc09
authored
Jul 15, 2019
by
thomwolf
Browse files
update QA models tests + run_generation
parent
15d8b126
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
27 deletions
+41
-27
examples/run_generation.py
examples/run_generation.py
+7
-10
examples/test_examples.py
examples/test_examples.py
+2
-1
pytorch_transformers/tests/modeling_xlm_test.py
pytorch_transformers/tests/modeling_xlm_test.py
+16
-8
pytorch_transformers/tests/modeling_xlnet_test.py
pytorch_transformers/tests/modeling_xlnet_test.py
+16
-8
No files found.
examples/run_generation.py
View file @
e691fc09
...
@@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
...
@@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name'
,
type
=
str
,
default
=
None
,
required
=
True
,
parser
.
add_argument
(
"--model_type"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: "
+
", "
.
join
(
ALL_MODELS
))
help
=
"Model type selected in the list: "
+
", "
.
join
(
MODEL_CLASSES
.
keys
()))
parser
.
add_argument
(
"--model_name_or_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to pre-trained model or shortcut name selected in the list: "
+
", "
.
join
(
ALL_MODELS
))
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--padding_text"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--padding_text"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--length"
,
type
=
int
,
default
=
20
)
parser
.
add_argument
(
"--length"
,
type
=
int
,
default
=
20
)
...
@@ -150,15 +152,10 @@ def main():
...
@@ -150,15 +152,10 @@ def main():
set_seed
(
args
)
set_seed
(
args
)
args
.
model_type
=
""
args
.
model_type
=
args
.
model_type
.
lower
()
for
key
in
MODEL_CLASSES
:
if
key
in
args
.
model_name
.
lower
():
args
.
model_type
=
key
# take the first match in model types
break
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_type
]
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name
_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name
)
model
=
model_class
.
from_pretrained
(
args
.
model_name
_or_path
)
model
.
to
(
args
.
device
)
model
.
to
(
args
.
device
)
model
.
eval
()
model
.
eval
()
...
...
examples/test_examples.py
View file @
e691fc09
...
@@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
...
@@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
"--prompt=Hello"
,
"--prompt=Hello"
,
"--length=10"
,
"--length=10"
,
"--seed=42"
]
"--seed=42"
]
model_name
=
"--model_name=openai-gpt"
model_type
,
model_name
=
(
"--model_type=openai-gpt"
,
"--model_name_or_path=openai-gpt"
)
with
patch
.
object
(
sys
,
'argv'
,
testargs
+
[
model_name
]):
with
patch
.
object
(
sys
,
'argv'
,
testargs
+
[
model_name
]):
result
=
run_generation
.
main
()
result
=
run_generation
.
main
()
self
.
assertGreaterEqual
(
len
(
result
),
10
)
self
.
assertGreaterEqual
(
len
(
result
),
10
)
...
...
pytorch_transformers/tests/modeling_xlm_test.py
View file @
e691fc09
...
@@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
...
@@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
cls_index
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
)
is_impossible
=
is_impossible_labels
)
total_loss
,
start_logits
,
end_logits
,
cls_logits
=
outputs
(
total_loss
,
)
=
outputs
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
end_positions
=
sequence_labels
)
total_loss
,
start_logits
,
end_logits
=
outputs
(
total_loss
,
)
=
outputs
result
=
{
result
=
{
"loss"
:
total_loss
,
"loss"
:
total_loss
,
"start_logits"
:
start_logits
,
"start_top_log_probs"
:
start_top_log_probs
,
"end_logits"
:
end_logits
,
"start_top_index"
:
start_top_index
,
"end_top_log_probs"
:
end_top_log_probs
,
"end_top_index"
:
end_top_index
,
"cls_logits"
:
cls_logits
,
"cls_logits"
:
cls_logits
,
}
}
...
@@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
...
@@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
list
(
result
[
"loss"
].
size
()),
list
(
result
[
"loss"
].
size
()),
[])
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_
logit
s"
].
size
()),
list
(
result
[
"start_
top_log_prob
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
[
self
.
batch_size
])
...
...
pytorch_transformers/tests/modeling_xlnet_test.py
View file @
e691fc09
...
@@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
cls_index
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
)
is_impossible
=
is_impossible_labels
)
total_loss
,
start_logits
,
end_logits
,
cls_logits
,
mems
=
outputs
total_loss
,
mems
=
outputs
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
end_positions
=
sequence_labels
)
total_loss
,
start_logits
,
end_logits
,
mems
=
outputs
total_loss
,
mems
=
outputs
result
=
{
result
=
{
"loss"
:
total_loss
,
"loss"
:
total_loss
,
"start_logits"
:
start_logits
,
"start_top_log_probs"
:
start_top_log_probs
,
"end_logits"
:
end_logits
,
"start_top_index"
:
start_top_index
,
"end_top_log_probs"
:
end_top_log_probs
,
"end_top_index"
:
end_top_index
,
"cls_logits"
:
cls_logits
,
"cls_logits"
:
cls_logits
,
"mems"
:
mems
,
"mems"
:
mems
,
}
}
...
@@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list
(
result
[
"loss"
].
size
()),
list
(
result
[
"loss"
].
size
()),
[])
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_
logit
s"
].
size
()),
list
(
result
[
"start_
top_log_prob
s"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
[
self
.
batch_size
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment