Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c6d9d539
Commit
c6d9d539
authored
Dec 05, 2018
by
Grégory Châtel
Browse files
Simplifying code for easier understanding.
parent
793262e8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
24 deletions
+10
-24
examples/run_classifier.py
examples/run_classifier.py
+10
-24
No files found.
examples/run_classifier.py
View file @
c6d9d539
...
...
@@ -196,9 +196,7 @@ class ColaProcessor(DataProcessor):
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
):
"""Loads a data file into a list of `InputBatch`s."""
label_map
=
{}
for
(
i
,
label
)
in
enumerate
(
label_list
):
label_map
[
label
]
=
i
label_map
=
{
label
:
i
for
i
,
label
in
enumerate
(
label_list
)}
features
=
[]
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
...
...
@@ -207,8 +205,6 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
tokens_b
=
None
if
example
.
text_b
:
tokens_b
=
tokenizer
.
tokenize
(
example
.
text_b
)
if
tokens_b
:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
...
...
@@ -216,7 +212,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
else
:
# Account for [CLS] and [SEP] with "- 2"
if
len
(
tokens_a
)
>
max_seq_length
-
2
:
tokens_a
=
tokens_a
[
0
:(
max_seq_length
-
2
)]
tokens_a
=
tokens_a
[:(
max_seq_length
-
2
)]
# The convention in BERT is:
# (a) For sequence pairs:
...
...
@@ -236,22 +232,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens
=
[]
segment_ids
=
[]
tokens
.
append
(
"[CLS]"
)
segment_ids
.
append
(
0
)
for
token
in
tokens_a
:
tokens
.
append
(
token
)
segment_ids
.
append
(
0
)
tokens
.
append
(
"[SEP]"
)
segment_ids
.
append
(
0
)
tokens
=
[
"[CLS]"
]
+
tokens_a
+
[
"[SEP]"
]
segment_ids
=
[
0
]
*
len
(
tokens
)
if
tokens_b
:
for
token
in
tokens_b
:
tokens
.
append
(
token
)
segment_ids
.
append
(
1
)
tokens
.
append
(
"[SEP]"
)
segment_ids
.
append
(
1
)
tokens
+=
tokens_b
+
[
"[SEP]"
]
segment_ids
+=
[
1
]
*
(
len
(
tokens_b
)
+
1
)
input_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
...
...
@@ -260,10 +246,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
input_mask
=
[
1
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
while
len
(
input_ids
)
<
max_seq_length
:
input_ids
.
append
(
0
)
input_mask
.
append
(
0
)
segment_ids
.
append
(
0
)
padding
=
[
0
]
*
(
max_seq_length
-
len
(
input_ids
))
input_ids
+=
padding
input_mask
+=
padding
segment_ids
+=
padding
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment