Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7334bf6c
Commit
7334bf6c
authored
Jun 24, 2019
by
thomwolf
Browse files
pad on left for xlnet
parent
c888663f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
15 deletions
+29
-15
examples/run_xlnet_classifier.py
examples/run_xlnet_classifier.py
+12
-6
examples/utils_glue.py
examples/utils_glue.py
+17
-9
No files found.
examples/run_xlnet_classifier.py
View file @
7334bf6c
...
@@ -198,14 +198,17 @@ def main():
...
@@ -198,14 +198,17 @@ def main():
list
(
filter
(
None
,
args
.
xlnet_model
.
split
(
'/'
))).
pop
(),
list
(
filter
(
None
,
args
.
xlnet_model
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
),
str
(
args
.
max_seq_length
),
str
(
task_name
)))
str
(
task_name
)))
try
:
if
os
.
path
.
exists
(
cached_train_features_file
):
logger
.
info
(
"Loading train features for cache file %s"
,
cached_train_features_file
)
with
open
(
cached_train_features_file
,
"rb"
)
as
reader
:
with
open
(
cached_train_features_file
,
"rb"
)
as
reader
:
train_features
=
pickle
.
load
(
reader
)
train_features
=
pickle
.
load
(
reader
)
except
:
else
:
logger
.
info
(
"No cache file at %s, preparing train features"
,
cached_train_features_file
)
train_features
=
convert_examples_to_features
(
train_features
=
convert_examples_to_features
(
train_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
train_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
cls_token_at_end
=
True
,
cls_token
=
tokenizer
.
CLS_TOKEN
,
cls_token_at_end
=
True
,
cls_token
=
tokenizer
.
CLS_TOKEN
,
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
)
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
,
pad_on_left
=
True
,
pad_token_segment_id
=
4
)
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
" Saving train features into cached file %s"
,
cached_train_features_file
)
logger
.
info
(
" Saving train features into cached file %s"
,
cached_train_features_file
)
with
open
(
cached_train_features_file
,
"wb"
)
as
writer
:
with
open
(
cached_train_features_file
,
"wb"
)
as
writer
:
...
@@ -344,14 +347,17 @@ def main():
...
@@ -344,14 +347,17 @@ def main():
list
(
filter
(
None
,
args
.
xlnet_model
.
split
(
'/'
))).
pop
(),
list
(
filter
(
None
,
args
.
xlnet_model
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
),
str
(
args
.
max_seq_length
),
str
(
task_name
)))
str
(
task_name
)))
try
:
if
os
.
path
.
exists
(
cached_eval_features_file
):
logger
.
info
(
"Loading eval features for cache file %s"
,
cached_eval_features_file
)
with
open
(
cached_eval_features_file
,
"rb"
)
as
reader
:
with
open
(
cached_eval_features_file
,
"rb"
)
as
reader
:
eval_features
=
pickle
.
load
(
reader
)
eval_features
=
pickle
.
load
(
reader
)
except
:
else
:
logger
.
info
(
"No cache file at %s, preparing eval features"
,
cached_eval_features_file
)
eval_features
=
convert_examples_to_features
(
eval_features
=
convert_examples_to_features
(
eval_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
eval_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
,
cls_token_at_end
=
True
,
cls_token
=
tokenizer
.
CLS_TOKEN
,
cls_token_at_end
=
True
,
cls_token
=
tokenizer
.
CLS_TOKEN
,
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
)
sep_token
=
tokenizer
.
SEP_TOKEN
,
cls_token_segment_id
=
2
,
pad_on_left
=
True
,
pad_token_segment_id
=
4
)
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
" Saving eval features into cached file %s"
,
cached_eval_features_file
)
logger
.
info
(
" Saving eval features into cached file %s"
,
cached_eval_features_file
)
with
open
(
cached_eval_features_file
,
"wb"
)
as
writer
:
with
open
(
cached_eval_features_file
,
"wb"
)
as
writer
:
...
...
examples/utils_glue.py
View file @
7334bf6c
...
@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
...
@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
,
output_mode
,
tokenizer
,
output_mode
,
cls_token_at_end
=
False
,
cls_token
=
'[CLS]'
,
cls_token_at_end
=
False
,
pad_on_left
=
False
,
sep_token
=
'[SEP]'
,
cls_token_segment_id
=
0
):
cls_token
=
'[CLS]'
,
sep_token
=
'[SEP]'
,
pad_token
=
0
,
sequence_a_segment_id
=
0
,
sequence_b_segment_id
=
1
,
cls_token_segment_id
=
1
,
pad_token_segment_id
=
0
,
mask_padding_with_zero
=
True
):
""" Loads a data file into a list of `InputBatch`s
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token:
`cls_token_at_end` define the location of the CLS token:
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
...
@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
...
@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# used as as the "sentence vector". Note that this only makes sense because
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
# the entire model is fine-tuned.
tokens
=
tokens_a
+
[
sep_token
]
tokens
=
tokens_a
+
[
sep_token
]
segment_ids
=
[
0
]
*
len
(
tokens
)
segment_ids
=
[
sequence_a_segment_id
]
*
len
(
tokens
)
if
tokens_b
:
if
tokens_b
:
tokens
+=
tokens_b
+
[
sep_token
]
tokens
+=
tokens_b
+
[
sep_token
]
segment_ids
+=
[
1
]
*
(
len
(
tokens_b
)
+
1
)
segment_ids
+=
[
sequence_b_segment_id
]
*
(
len
(
tokens_b
)
+
1
)
if
cls_token_at_end
:
if
cls_token_at_end
:
tokens
=
tokens
+
[
cls_token
]
tokens
=
tokens
+
[
cls_token
]
...
@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
...
@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
# tokens are attended to.
input_mask
=
[
1
]
*
len
(
input_ids
)
input_mask
=
[
1
if
mask_padding_with_zero
else
0
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
# Zero-pad up to the sequence length.
padding
=
[
0
]
*
(
max_seq_length
-
len
(
input_ids
))
padding_length
=
max_seq_length
-
len
(
input_ids
)
input_ids
+=
padding
if
pad_on_left
:
input_mask
+=
padding
input_ids
=
([
pad_token
]
*
padding_length
)
+
input_ids
segment_ids
+=
padding
input_mask
=
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
+
input_mask
segment_ids
=
([
pad_token_segment_id
]
*
padding_length
)
+
segment_ids
else
:
input_ids
=
input_ids
+
([
pad_token
]
*
padding_length
)
input_mask
=
input_mask
+
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
segment_ids
=
segment_ids
+
([
pad_token_segment_id
]
*
padding_length
)
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment