chenpangpang / transformers · Commits

Commit 7334bf6c
Authored Jun 24, 2019 by thomwolf

    pad on left for xlnet

Parent: c888663f

Showing 2 changed files with 29 additions and 15 deletions:

  examples/run_xlnet_classifier.py  (+12, -6)
  examples/utils_glue.py            (+17, -9)
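Background for the change: XLNet reads its classification token at the end of the sequence, and the reference implementation pads inputs on the left rather than the right, using segment id 2 for the CLS token and 4 for padding. This commit threads those conventions through the GLUE example scripts as keyword arguments. A minimal sketch of the target layout, with hypothetical toy tokens (not part of the commit):

    # XLNet-style input for a single sequence A of three tokens, padded to
    # length 6: padding sits on the left, CLS at the very end.
    tokens      = ["<pad>", "<pad>", "_hello", "_world", "<sep>", "<cls>"]
    segment_ids = [4,       4,       0,        0,        0,       2]  # 4 = pad, 0 = seq A, 2 = CLS
    input_mask  = [0,       0,       1,        1,        1,       1]  # 1 = real token, 0 = padding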
examples/run_xlnet_classifier.py
@@ -198,14 +198,17 @@ def main():
             list(filter(None, args.xlnet_model.split('/'))).pop(),
             str(args.max_seq_length),
             str(task_name)))
-        try:
+        if os.path.exists(cached_train_features_file):
+            logger.info("Loading train features for cache file %s", cached_train_features_file)
             with open(cached_train_features_file, "rb") as reader:
                 train_features = pickle.load(reader)
-        except:
+        else:
+            logger.info("No cache file at %s, preparing train features", cached_train_features_file)
             train_features = convert_examples_to_features(
                 train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
                 cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-                sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
+                sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+                pad_on_left=True, pad_token_segment_id=4)
             if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                 logger.info("  Saving train features into cached file %s", cached_train_features_file)
                 with open(cached_train_features_file, "wb") as writer:
@@ -344,14 +347,17 @@ def main():
             list(filter(None, args.xlnet_model.split('/'))).pop(),
             str(args.max_seq_length),
             str(task_name)))
-        try:
+        if os.path.exists(cached_eval_features_file):
+            logger.info("Loading eval features for cache file %s", cached_eval_features_file)
             with open(cached_eval_features_file, "rb") as reader:
                 eval_features = pickle.load(reader)
-        except:
+        else:
+            logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
             eval_features = convert_examples_to_features(
                 eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
                 cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-                sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
+                sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+                pad_on_left=True, pad_token_segment_id=4)
             if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                 logger.info("  Saving eval features into cached file %s", cached_eval_features_file)
                 with open(cached_eval_features_file, "wb") as writer:
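Beyond the new padding arguments, both hunks also replace a bare try/except around the cached-features pickle with an explicit os.path.exists check plus logging, so unrelated failures (a corrupt cache file, a bad path) are no longer silently swallowed. The pattern in isolation, as a self-contained sketch built around a hypothetical load_or_build_features helper:

    import logging
    import os
    import pickle

    logger = logging.getLogger(__name__)

    def load_or_build_features(cache_file, build_fn):
        # Explicit existence check: only a missing cache triggers a rebuild;
        # any other error (e.g. a truncated pickle) propagates instead of
        # being masked by a bare "except:".
        if os.path.exists(cache_file):
            logger.info("Loading features from cache file %s", cache_file)
            with open(cache_file, "rb") as reader:
                return pickle.load(reader)
        logger.info("No cache file at %s, preparing features", cache_file)
        return build_fn()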
examples/utils_glue.py
@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):

 def convert_examples_to_features(examples, label_list, max_seq_length,
                                  tokenizer, output_mode,
-                                 cls_token_at_end=False, cls_token='[CLS]',
-                                 sep_token='[SEP]', cls_token_segment_id=0):
+                                 cls_token_at_end=False, pad_on_left=False,
+                                 cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
+                                 sequence_a_segment_id=0, sequence_b_segment_id=1,
+                                 cls_token_segment_id=1, pad_token_segment_id=0,
+                                 mask_padding_with_zero=True):
     """ Loads a data file into a list of `InputBatch`s
         `cls_token_at_end` define the location of the CLS token:
             - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
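The expanded signature turns every model-specific layout decision (CLS position, padding side, pad token id, per-segment ids, attention-mask convention) into a keyword argument, so one helper serves both families of models; note the cls_token_segment_id default also moves from 0 to 1 here. A hypothetical pair of calls contrasting the two styles, with examples, tokenizer, etc. standing in for objects the example scripts construct:

    # BERT-style (the keyword defaults): [CLS] first, pad on the right.
    features_bert = convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, output_mode)

    # XLNet-style, as run_xlnet_classifier.py now calls it: CLS at the end,
    # pad on the left, XLNet segment ids for CLS (2) and padding (4).
    features_xlnet = convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, output_mode,
        cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
        sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
        pad_on_left=True, pad_token_segment_id=4)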
@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
         # used as as the "sentence vector". Note that this only makes sense because
         # the entire model is fine-tuned.
         tokens = tokens_a + [sep_token]
-        segment_ids = [0] * len(tokens)
+        segment_ids = [sequence_a_segment_id] * len(tokens)

         if tokens_b:
             tokens += tokens_b + [sep_token]
-            segment_ids += [1] * (len(tokens_b) + 1)
+            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

         if cls_token_at_end:
             tokens = tokens + [cls_token]
@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,

         # The mask has 1 for real tokens and 0 for padding tokens. Only real
         # tokens are attended to.
-        input_mask = [1] * len(input_ids)
+        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

         # Zero-pad up to the sequence length.
-        padding = [0] * (max_seq_length - len(input_ids))
-        input_ids += padding
-        input_mask += padding
-        segment_ids += padding
+        padding_length = max_seq_length - len(input_ids)
+        if pad_on_left:
+            input_ids = ([pad_token] * padding_length) + input_ids
+            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+        else:
+            input_ids = input_ids + ([pad_token] * padding_length)
+            input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+            segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

         assert len(input_ids) == max_seq_length
         assert len(input_mask) == max_seq_length
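The padding logic above, extracted into a self-contained, runnable sketch (a hypothetical pad_features helper; the real function also tokenizes the examples and builds labels):

    def pad_features(input_ids, input_mask, segment_ids, max_seq_length,
                     pad_on_left=False, pad_token=0, pad_token_segment_id=0,
                     mask_padding_with_zero=True):
        # Mirror of the commit's logic: compute how much padding is needed,
        # then prepend or append it depending on pad_on_left.
        padding_length = max_seq_length - len(input_ids)
        mask_pad = [0 if mask_padding_with_zero else 1] * padding_length
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = mask_pad + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            input_mask = input_mask + mask_pad
            segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
        return input_ids, input_mask, segment_ids

    # XLNet-style left padding with pad_token_segment_id=4:
    ids, mask, segs = pad_features([5, 6, 7], [1, 1, 1], [0, 0, 2], 6,
                                   pad_on_left=True, pad_token_segment_id=4)
    assert ids  == [0, 0, 0, 5, 6, 7]
    assert mask == [0, 0, 0, 1, 1, 1]
    assert segs == [4, 4, 4, 0, 0, 2]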