chenpangpang / transformers · Commits

Commit f2b873e9, parent 83fdbd60
Authored Dec 06, 2018 by Grégory Châtel

    convert_examples_to_features code and small improvements.

Showing 1 changed file (examples/run_swag.py) with 138 additions and 16 deletions (+138 / -16).
@@ -16,6 +16,15 @@
 import pandas as pd
+import logging
+
+from pytorch_pretrained_bert.tokenization import BertTokenizer
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+logger = logging.getLogger(__name__)


 class SwagExample(object):
     """A single training/test example for the SWAG dataset."""
@@ -31,10 +40,12 @@ class SwagExample(object):
         self.swag_id = swag_id
         self.context_sentence = context_sentence
         self.start_ending = start_ending
-        self.ending_0 = ending_0
-        self.ending_1 = ending_1
-        self.ending_2 = ending_2
-        self.ending_3 = ending_3
+        self.endings = [
+            ending_0,
+            ending_1,
+            ending_2,
+            ending_3,
+        ]
         self.label = label

     def __str__(self):
@@ -42,19 +53,37 @@ class SwagExample(object):
     def __repr__(self):
         l = [
-            f'swag_id: {self.swag_id}',
-            f'context_sentence: {self.context_sentence}',
-            f'start_ending: {self.start_ending}',
-            f'ending_0: {self.ending_0}',
-            f'ending_1: {self.ending_1}',
-            f'ending_2: {self.ending_2}',
-            f'ending_3: {self.ending_3}',
+            f"swag_id: {self.swag_id}",
+            f"context_sentence: {self.context_sentence}",
+            f"start_ending: {self.start_ending}",
+            f"ending_0: {self.endings[0]}",
+            f"ending_1: {self.endings[1]}",
+            f"ending_2: {self.endings[2]}",
+            f"ending_3: {self.endings[3]}",
         ]

         if self.label is not None:
-            l.append(f'label: {self.label}')
+            l.append(f"label: {self.label}")

-        return ', '.join(l)
+        return ", ".join(l)
+
+
+class InputFeatures(object):
+    def __init__(self,
+                 unique_id,
+                 example_id,
+                 input_ids,
+                 input_mask,
+                 segment_ids,
+                 label_id):
+        self.unique_id = unique_id
+        self.example_id = example_id
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.label_id = label_id


 def read_swag_examples(input_file, is_training):
     input_df = pd.read_csv(input_file)
@@ -67,7 +96,9 @@ def read_swag_examples(input_file, is_training):
         SwagExample(
             swag_id = row['fold-ind'],
             context_sentence = row['sent1'],
-            start_ending = row['sent2'],
+            start_ending = row['sent2'], # in the swag dataset, the
+                                         # common beginning of each
+                                         # choice is stored in "sent2".
             ending_0 = row['ending0'],
             ending_1 = row['ending1'],
             ending_2 = row['ending2'],
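For orientation: each CSV row carries the shared context in "sent1", the common beginning of all four endings in "sent2", the candidate continuations in "ending0" through "ending3", and (for training data) the index of the correct ending in "label". A hypothetical row, with invented values, would map onto a SwagExample like this:

    # Hypothetical values, for illustration only -- not from the real dataset.
    example = SwagExample(swag_id='3416',
                          context_sentence='Someone walks into the kitchen.',
                          start_ending='He',
                          ending_0='opens the fridge.',
                          ending_1='waters the plants.',
                          ending_2='falls asleep.',
                          ending_3='answers the phone.',
                          label=0)
    print(repr(example))  # comma-joined listing of the fields above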
@@ -79,9 +110,100 @@ def read_swag_examples(input_file, is_training):
     return examples

+def convert_examples_to_features(examples, tokenizer, max_seq_length,
+                                 is_training):
+    """Loads a data file into a list of `InputBatch`s."""
+
+    # Swag is a multiple choice task. To perform this task using Bert,
+    # we will use the formatting proposed in "Improving Language
+    # Understanding by Generative Pre-Training" and suggested by
+    # @jacobdevlin-google in this issue
+    # https://github.com/google-research/bert/issues/38.
+    #
+    # Each choice will correspond to a sample on which we run the
+    # inference. For a given Swag example, we will create the 4
+    # following inputs:
+    # - [CLS] context [SEP] choice_1 [SEP]
+    # - [CLS] context [SEP] choice_2 [SEP]
+    # - [CLS] context [SEP] choice_3 [SEP]
+    # - [CLS] context [SEP] choice_4 [SEP]
+    # The model will output a single value for each input. To get the
+    # final decision of the model, we will run a softmax over these 4
+    # outputs.
+    features = []
+    for example_index, example in enumerate(examples):
+        context_tokens = tokenizer.tokenize(example.context_sentence)
+        start_ending_tokens = tokenizer.tokenize(example.start_ending)
+
+        choices_features = []
+        for ending_index, ending in enumerate(example.endings):
+            # We create a copy of the context tokens in order to be
+            # able to shrink it according to ending_tokens
+            context_tokens_choice = context_tokens[:]
+            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
+            # Modifies `context_tokens_choice` and `ending_tokens` in
+            # place so that the total length is less than the
+            # specified length. Account for [CLS], [SEP], [SEP] with "- 3"
+            _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
+
+            tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
+            segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
+
+            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            input_mask = [1] * len(input_ids)
+
+            # Zero-pad up to the sequence length.
+            padding = [0] * (max_seq_length - len(input_ids))
+            input_ids += padding
+            input_mask += padding
+            segment_ids += padding
+
+            assert len(input_ids) == max_seq_length
+            assert len(input_mask) == max_seq_length
+            assert len(segment_ids) == max_seq_length
+
+            choices_features.append((tokens, input_ids, input_mask, segment_ids))
+
+        label = example.label
+        if example_index < 5:
+            logger.info("*** Example ***")
+            logger.info(f"swag_id: {example.swag_id}")
+            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
+                logger.info(f"choice: {choice_idx}")
+                logger.info(f"tokens: {' '.join(tokens)}")
+                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
+                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
+                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
+            if is_training:
+                logger.info(f"label: {label}")
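To make the packing concrete, here is a minimal sketch of one choice's layout before padding (token strings invented for illustration; a real run would use the WordPiece pieces produced by BertTokenizer):

    # Hypothetical token lists, for illustration only.
    context_tokens_choice = ['someone', 'walks', 'in']   # shared context
    ending_tokens = ['he', 'opens', 'the', 'fridge']     # start_ending + one choice

    tokens = ['[CLS]'] + context_tokens_choice + ['[SEP]'] + ending_tokens + ['[SEP]']
    # ['[CLS]', 'someone', 'walks', 'in', '[SEP]', 'he', 'opens', 'the', 'fridge', '[SEP]']
    segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
    # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]  -- [CLS]+context+[SEP] in segment 0, ending+[SEP] in segment 1

The remaining positions up to max_seq_length are then zero-padded in input_ids, input_mask and segment_ids alike.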
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+    """Truncates a sequence pair in place to the maximum length."""
+
+    # This is a simple heuristic which will always truncate the longer sequence
+    # one token at a time. This makes more sense than truncating an equal percent
+    # of tokens from each, since if one sequence is very short then each token
+    # that's truncated likely contains more information than a longer sequence.
+    while True:
+        total_length = len(tokens_a) + len(tokens_b)
+        if total_length <= max_length:
+            break
+        if len(tokens_a) > len(tokens_b):
+            tokens_a.pop()
+        else:
+            tokens_b.pop()
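A quick worked example of the truncation heuristic, with hypothetical lengths: given 10 context tokens, 4 ending tokens, and a budget of 12, only the longer list loses tokens.

    a = list(range(10))   # stands in for context_tokens_choice
    b = list(range(4))    # stands in for ending_tokens
    _truncate_seq_pair(a, b, 12)
    assert (len(a), len(b)) == (8, 4)  # two pops from the longer list, none from the shorter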
 if __name__ == "__main__":
-    examples = read_swag_examples('data/train.csv', True)
+    is_training = True
+    max_seq_length = 80
+    examples = read_swag_examples('data/train.csv', is_training)
     print(len(examples))

     for example in examples[:5]:
-        print('###########################')
+        print("###########################")
         print(example)
+
+    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    convert_examples_to_features(examples, tokenizer, max_seq_length,
+                                 is_training)
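Note that, as committed, convert_examples_to_features initializes a features list but never appends to it or returns it, and InputFeatures is not yet constructed anywhere; the __main__ block accordingly discards the call's result. A hedged sketch of the natural completion, with an assumed field mapping that is not part of this commit, would append one InputFeatures per example at the end of the outer loop:

    # Sketch only: the exact mapping of choices_features into InputFeatures
    # is an assumption, not this commit's code.
    features.append(InputFeatures(unique_id=example_index,
                                  example_id=example.swag_id,
                                  input_ids=[cf[1] for cf in choices_features],
                                  input_mask=[cf[2] for cf in choices_features],
                                  segment_ids=[cf[3] for cf in choices_features],
                                  label_id=label))
    # ...and `return features` after the loop.

Scoring then follows the comment at the top of the function: the model emits one value per choice, and a softmax over the four values gives the prediction, e.g. in plain torch:

    import torch
    logits = torch.tensor([1.2, -0.3, 0.8, 0.1])  # hypothetical per-choice model outputs
    probs = torch.softmax(logits, dim=0)          # normalized over the 4 choices
    prediction = int(torch.argmax(probs))         # index of the selected ending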