Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
74c50358
Commit
74c50358
authored
Oct 14, 2019
by
hlums
Browse files
Fix token order in xlnet preprocessing.
parent
80889a02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
12 deletions
+35
-12
examples/run_squad.py
examples/run_squad.py
+5
-1
examples/utils_squad.py
examples/utils_squad.py
+30
-11
No files found.
examples/run_squad.py
View file @
74c50358
...
@@ -302,7 +302,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
...
@@ -302,7 +302,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
max_seq_length
=
args
.
max_seq_length
,
max_seq_length
=
args
.
max_seq_length
,
doc_stride
=
args
.
doc_stride
,
doc_stride
=
args
.
doc_stride
,
max_query_length
=
args
.
max_query_length
,
max_query_length
=
args
.
max_query_length
,
is_training
=
not
evaluate
)
is_training
=
not
evaluate
,
cls_token_segment_id
=
2
if
args
.
model_type
in
[
'xlnet'
]
else
0
,
pad_token_segment_id
=
3
if
args
.
model_type
in
[
'xlnet'
]
else
0
,
cls_token_at_end
=
True
if
args
.
model_type
in
[
'xlnet'
]
else
False
,
sequence_a_is_doc
=
True
if
args
.
model_type
in
[
'xlnet'
]
else
False
)
if
args
.
local_rank
in
[
-
1
,
0
]:
if
args
.
local_rank
in
[
-
1
,
0
]:
logger
.
info
(
"Saving features into cached file %s"
,
cached_features_file
)
logger
.
info
(
"Saving features into cached file %s"
,
cached_features_file
)
torch
.
save
(
features
,
cached_features_file
)
torch
.
save
(
features
,
cached_features_file
)
...
...
examples/utils_squad.py
View file @
74c50358
...
@@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
cls_token
=
'[CLS]'
,
sep_token
=
'[SEP]'
,
pad_token
=
0
,
cls_token
=
'[CLS]'
,
sep_token
=
'[SEP]'
,
pad_token
=
0
,
sequence_a_segment_id
=
0
,
sequence_b_segment_id
=
1
,
sequence_a_segment_id
=
0
,
sequence_b_segment_id
=
1
,
cls_token_segment_id
=
0
,
pad_token_segment_id
=
0
,
cls_token_segment_id
=
0
,
pad_token_segment_id
=
0
,
mask_padding_with_zero
=
True
):
mask_padding_with_zero
=
True
,
sequence_a_is_doc
=
False
):
"""Loads a data file into a list of `InputBatch`s."""
"""Loads a data file into a list of `InputBatch`s."""
unique_id
=
1000000000
unique_id
=
1000000000
...
@@ -272,17 +273,19 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -272,17 +273,19 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
p_mask
.
append
(
0
)
p_mask
.
append
(
0
)
cls_index
=
0
cls_index
=
0
# Query
# XLNet: P SEP Q SEP CLS
for
token
in
query_tokens
:
# Others: CLS Q SEP P SEP
tokens
.
append
(
token
)
if
not
sequence_a_is_doc
:
# Query
tokens
+=
query_tokens
segment_ids
+=
[
sequence_a_segment_id
]
*
len
(
query_tokens
)
p_mask
+=
[
1
]
*
len
(
query_tokens
)
# SEP token
tokens
.
append
(
sep_token
)
segment_ids
.
append
(
sequence_a_segment_id
)
segment_ids
.
append
(
sequence_a_segment_id
)
p_mask
.
append
(
1
)
p_mask
.
append
(
1
)
# SEP token
tokens
.
append
(
sep_token
)
segment_ids
.
append
(
sequence_a_segment_id
)
p_mask
.
append
(
1
)
# Paragraph
# Paragraph
for
i
in
range
(
doc_span
.
length
):
for
i
in
range
(
doc_span
.
length
):
split_token_index
=
doc_span
.
start
+
i
split_token_index
=
doc_span
.
start
+
i
...
@@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
split_token_index
)
split_token_index
)
token_is_max_context
[
len
(
tokens
)]
=
is_max_context
token_is_max_context
[
len
(
tokens
)]
=
is_max_context
tokens
.
append
(
all_doc_tokens
[
split_token_index
])
tokens
.
append
(
all_doc_tokens
[
split_token_index
])
segment_ids
.
append
(
sequence_b_segment_id
)
if
not
sequence_a_is_doc
:
segment_ids
.
append
(
sequence_b_segment_id
)
else
:
segment_ids
.
append
(
sequence_a_segment_id
)
p_mask
.
append
(
0
)
p_mask
.
append
(
0
)
paragraph_len
=
doc_span
.
length
paragraph_len
=
doc_span
.
length
if
sequence_a_is_doc
:
# SEP token
tokens
.
append
(
sep_token
)
segment_ids
.
append
(
sequence_a_segment_id
)
p_mask
.
append
(
1
)
tokens
+=
query_tokens
segment_ids
+=
[
sequence_b_segment_id
]
*
len
(
query_tokens
)
p_mask
+=
[
1
]
*
len
(
query_tokens
)
# SEP token
# SEP token
tokens
.
append
(
sep_token
)
tokens
.
append
(
sep_token
)
segment_ids
.
append
(
sequence_b_segment_id
)
segment_ids
.
append
(
sequence_b_segment_id
)
...
@@ -342,7 +358,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -342,7 +358,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
end_position
=
0
end_position
=
0
span_is_impossible
=
True
span_is_impossible
=
True
else
:
else
:
doc_offset
=
len
(
query_tokens
)
+
2
if
sequence_a_is_doc
:
doc_offset
=
0
else
:
doc_offset
=
len
(
query_tokens
)
+
2
start_position
=
tok_start_position
-
doc_start
+
doc_offset
start_position
=
tok_start_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment