Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7334bf6c
Commit
7334bf6c
authored
Jun 24, 2019
by
thomwolf
Browse files
pad on left for xlnet
parent
c888663f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
15 deletions
+29
-15
examples/run_xlnet_classifier.py
examples/run_xlnet_classifier.py
+12
-6
examples/utils_glue.py
examples/utils_glue.py
+17
-9
No files found.
examples/run_xlnet_classifier.py
View file @
7334bf6c
...
...
@@ -198,14 +198,17 @@ def main():
list
(
filter
(
None
,
args
.
xlnet_model
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
),
str
(
task_name
)))
try
:
if
os
.
path
.
exists
(
cached_train_features_file
):
logger
.
info
(
"Loading train features for cache file %s"
,
cached_train_features_file
)
with
open
(
cached_train_features_file
,
"rb"
)
as
reader
:
train_features
=
pickle
.
load
(
reader
)
except
:
else
:
logger
.
info
(
"No cache file at %s, preparing train features"
,
cached_train_features_file
)
train_features
=
convert_examples_to_features
(
train_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
cls_token_at_end
=
True
,
cls_token
=
tokenizer
.
CLS_TOKEN
,
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
)
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
,
pad_on_left
=
True
,
pad_token_segment_id
=
4
)
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
" Saving train features into cached file %s"
,
cached_train_features_file
)
with
open
(
cached_train_features_file
,
"wb"
)
as
writer
:
...
...
@@ -344,14 +347,17 @@ def main():
list
(
filter
(
None
,
args
.
xlnet_model
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
),
str
(
task_name
)))
try
:
if
os
.
path
.
exists
(
cached_eval_features_file
):
logger
.
info
(
"Loading eval features for cache file %s"
,
cached_eval_features_file
)
with
open
(
cached_eval_features_file
,
"rb"
)
as
reader
:
eval_features
=
pickle
.
load
(
reader
)
except
:
else
:
logger
.
info
(
"No cache file at %s, preparing eval features"
,
cached_eval_features_file
)
eval_features
=
convert_examples_to_features
(
eval_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
cls_token_at_end
=
True
,
cls_token
=
tokenizer
.
CLS_TOKEN
,
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
)
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
,
pad_on_left
=
True
,
pad_token_segment_id
=
4
)
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
" Saving eval features into cached file %s"
,
cached_eval_features_file
)
with
open
(
cached_eval_features_file
,
"wb"
)
as
writer
:
...
...
examples/utils_glue.py
View file @
7334bf6c
...
...
@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
,
output_mode
,
cls_token_at_end
=
False
,
cls_token
=
'[CLS]'
,
sep_token
=
'[SEP]'
,
cls_token_segment_id
=
0
):
cls_token_at_end
=
False
,
pad_on_left
=
False
,
cls_token
=
'[CLS]'
,
sep_token
=
'[SEP]'
,
pad_token
=
0
,
sequence_a_segment_id
=
0
,
sequence_b_segment_id
=
1
,
cls_token_segment_id
=
1
,
pad_token_segment_id
=
0
,
mask_padding_with_zero
=
True
):
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` defines the location of the CLS token:
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
...
...
@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens
=
tokens_a
+
[
sep_token
]
segment_ids
=
[
0
]
*
len
(
tokens
)
segment_ids
=
[
sequence_a_segment_id
]
*
len
(
tokens
)
if
tokens_b
:
tokens
+=
tokens_b
+
[
sep_token
]
segment_ids
+=
[
1
]
*
(
len
(
tokens_b
)
+
1
)
segment_ids
+=
[
sequence_b_segment_id
]
*
(
len
(
tokens_b
)
+
1
)
if
cls_token_at_end
:
tokens
=
tokens
+
[
cls_token
]
...
...
@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
]
*
len
(
input_ids
)
input_mask
=
[
1
if
mask_padding_with_zero
else
0
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
padding
=
[
0
]
*
(
max_seq_length
-
len
(
input_ids
))
input_ids
+=
padding
input_mask
+=
padding
segment_ids
+=
padding
padding_length
=
max_seq_length
-
len
(
input_ids
)
if
pad_on_left
:
input_ids
=
([
pad_token
]
*
padding_length
)
+
input_ids
input_mask
=
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
+
input_mask
segment_ids
=
([
pad_token_segment_id
]
*
padding_length
)
+
segment_ids
else
:
input_ids
=
input_ids
+
([
pad_token
]
*
padding_length
)
input_mask
=
input_mask
+
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
segment_ids
=
segment_ids
+
([
pad_token_segment_id
]
*
padding_length
)
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment