Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0812aee2
Commit
0812aee2
authored
Dec 06, 2018
by
Grégory Châtel
Browse files
Fixing problems in convert_examples_to_features.
parent
f2b873e9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
13 deletions
+14
-13
examples/run_swag.py
examples/run_swag.py
+14
-13
No files found.
examples/run_swag.py
View file @
0812aee2
...
...
@@ -70,20 +70,13 @@ class SwagExample(object):
class
InputFeatures
(
object
):
def
__init__
(
self
,
unique_id
,
example_id
,
input_ids
,
input_mask
,
segment_ids
,
label_id
choices_features
,
label
):
self
.
unique_id
=
unique_id
self
.
example_id
=
example_id
self
.
input_ids
=
input_ids
self
.
input_mask
=
input_mask
self
.
segment_ids
=
segment_ids
self
.
label_id
=
label_id
self
.
choices_features
=
choices_features
self
.
label
=
label
def
read_swag_examples
(
input_file
,
is_training
):
input_df
=
pd
.
read_csv
(
input_file
)
...
...
@@ -145,7 +138,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
# place so that the total length is less than the
# specified length. Account for [CLS], [SEP], [SEP] with
# "- 3"
_truncate_seq_pair
(
context_tokens
,
ending_tokens
,
max_seq_length
-
3
)
_truncate_seq_pair
(
context_tokens
_choice
,
ending_tokens
,
max_seq_length
-
3
)
tokens
=
[
"[CLS]"
]
+
context_tokens_choice
+
[
"[SEP]"
]
+
ending_tokens
+
[
"[SEP]"
]
segment_ids
=
[
0
]
*
(
len
(
context_tokens_choice
)
+
2
)
+
[
1
]
*
(
len
(
ending_tokens
)
+
1
)
...
...
@@ -178,7 +171,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
if
is_training
:
logger
.
info
(
f
"label:
{
label
}
"
)
features
.
append
(
InputFeatures
(
example_id
=
example
.
swag_id
,
choices_features
=
choices_features
,
label
=
label
)
)
return
features
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
"""Truncates a sequence pair in place to the maximum length."""
...
...
@@ -206,4 +207,4 @@ if __name__ == "__main__":
print
(
"###########################"
)
print
(
example
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
"bert-base-uncased"
)
convert_examples_to_features
(
examples
,
tokenizer
,
max_seq_length
,
is_training
)
features
=
convert_examples_to_features
(
examples
,
tokenizer
,
max_seq_length
,
is_training
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment