chenpangpang / transformers · Commits

Commit c6d9d539, authored Dec 05, 2018 by Grégory Châtel

Simplifying code for easier understanding.

Parent: 793262e8
Showing 1 changed file with 10 additions and 24 deletions.

examples/run_classifier.py (+10, -24)
@@ -196,9 +196,7 @@ class ColaProcessor(DataProcessor):
 def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
     """Loads a data file into a list of `InputBatch`s."""
-    label_map = {}
-    for (i, label) in enumerate(label_list):
-        label_map[label] = i
+    label_map = {label : i for i, label in enumerate(label_list)}
 
     features = []
     for (ex_index, example) in enumerate(examples):
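This hunk replaces the three-line imperative loop with an equivalent dict comprehension; both produce the same label-to-index mapping. A minimal sketch of the equivalence, using a hypothetical two-label list (not taken from the diff):

# Hypothetical label list for illustration only.
label_list = ["0", "1"]

# Deleted form: build the mapping imperatively.
label_map_loop = {}
for (i, label) in enumerate(label_list):
    label_map_loop[label] = i

# Added form: the same mapping as a dict comprehension.
label_map_comp = {label: i for i, label in enumerate(label_list)}

assert label_map_loop == label_map_comp  # both are {"0": 0, "1": 1}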
@@ -207,8 +205,6 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
         tokens_b = None
         if example.text_b:
             tokens_b = tokenizer.tokenize(example.text_b)
 
         if tokens_b:
-            # Modifies `tokens_a` and `tokens_b` in place so that the total
-            # length is less than the specified length.
             # Account for [CLS], [SEP], [SEP] with "- 3"
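The two deleted comment lines described the in-place truncation of the token pair that happens just below this hunk (the call itself is outside the diff context). For reference, a sketch of what that helper does; run_classifier.py names it _truncate_seq_pair, but its body is not part of this commit, so treat this as a reconstruction:

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length.

    Pops one token at a time from the longer of the two lists,
    so the truncation is spread across both sequences.
    """
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()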
@@ -216,7 +212,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
         else:
             # Account for [CLS] and [SEP] with "- 2"
             if len(tokens_a) > max_seq_length - 2:
-                tokens_a = tokens_a[0:(max_seq_length - 2)]
+                tokens_a = tokens_a[:(max_seq_length - 2)]
 
         # The convention in BERT is:
         # (a) For sequence pairs:
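This hunk only drops the redundant explicit start index: for any Python sequence, seq[0:n] and seq[:n] denote the same slice. A quick check with a dummy token list:

# Dummy tokens for illustration; any list behaves the same way.
tokens_a = ["the", "cat", "sat", "on", "the", "mat"]
max_seq_length = 5

# Both slices keep the first max_seq_length - 2 tokens,
# leaving room for the [CLS] and [SEP] markers.
assert tokens_a[0:(max_seq_length - 2)] == tokens_a[:(max_seq_length - 2)]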
@@ -236,22 +232,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
         # For classification tasks, the first vector (corresponding to [CLS]) is
         # used as as the "sentence vector". Note that this only makes sense because
         # the entire model is fine-tuned.
-        tokens = []
-        segment_ids = []
-        tokens.append("[CLS]")
-        segment_ids.append(0)
-        for token in tokens_a:
-            tokens.append(token)
-            segment_ids.append(0)
-        tokens.append("[SEP]")
-        segment_ids.append(0)
+        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
+        segment_ids = [0] * len(tokens)
 
         if tokens_b:
-            for token in tokens_b:
-                tokens.append(token)
-                segment_ids.append(1)
-            tokens.append("[SEP]")
-            segment_ids.append(1)
+            tokens += tokens_b + ["[SEP]"]
+            segment_ids += [1] * (len(tokens_b) + 1)
 
         input_ids = tokenizer.convert_tokens_to_ids(tokens)
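Here the append-based construction collapses into list concatenation and repetition. The segment ids stay aligned with the tokens because each [SEP] takes the segment id of the sentence it closes: 0 for the first sentence, 1 for the second. A small sketch with dummy token lists, checking that both forms agree:

# Dummy WordPiece tokens for illustration only.
tokens_a = ["hello", "world"]
tokens_b = ["good", "bye"]

# Deleted form: build both lists by appending element by element.
tokens_old, segment_ids_old = [], []
tokens_old.append("[CLS]")
segment_ids_old.append(0)
for token in tokens_a:
    tokens_old.append(token)
    segment_ids_old.append(0)
tokens_old.append("[SEP]")
segment_ids_old.append(0)
for token in tokens_b:
    tokens_old.append(token)
    segment_ids_old.append(1)
tokens_old.append("[SEP]")
segment_ids_old.append(1)

# Added form: concatenation plus list repetition.
tokens_new = ["[CLS]"] + tokens_a + ["[SEP]"]
segment_ids_new = [0] * len(tokens_new)
tokens_new += tokens_b + ["[SEP]"]
segment_ids_new += [1] * (len(tokens_b) + 1)

assert tokens_old == tokens_new
assert segment_ids_old == segment_ids_new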
@@ -260,10 +246,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
         input_mask = [1] * len(input_ids)
 
         # Zero-pad up to the sequence length.
-        while len(input_ids) < max_seq_length:
-            input_ids.append(0)
-            input_mask.append(0)
-            segment_ids.append(0)
+        padding = [0] * (max_seq_length - len(input_ids))
+        input_ids += padding
+        input_mask += padding
+        segment_ids += padding
 
         assert len(input_ids) == max_seq_length
         assert len(input_mask) == max_seq_length
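The padding hunk applies the same idea: instead of a while loop appending one zero at a time to three lists, a single padding list is computed once and concatenated onto each. This works because input_ids, input_mask, and segment_ids all have the same length at this point. A sketch with dummy ids:

# Dummy ids for illustration; real values come from the tokenizer.
max_seq_length = 8
input_ids = [101, 7592, 2088, 102]
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)

padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding

assert len(input_ids) == len(input_mask) == len(segment_ids) == max_seq_length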