"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "0d21f366adeea29ef816ff137f4febc71c2416a7"
Commit 74c50358 authored by hlums

Fix token order in xlnet preprocessing.

parent 80889a02
@@ -302,7 +302,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
             max_seq_length=args.max_seq_length,
             doc_stride=args.doc_stride,
             max_query_length=args.max_query_length,
-            is_training=not evaluate)
+            is_training=not evaluate,
+            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
+            pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
+            cls_token_at_end=True if args.model_type in ['xlnet'] else False,
+            sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)
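The XLNet-specific values passed above (cls_token_segment_id=2, pad_token_segment_id=3) feed into the segment ids built further down. A rough, illustrative sketch of the ids an XLNet feature would end up with, assuming the default sequence_a_segment_id=0 and sequence_b_segment_id=1 from the signature in the next hunk and a toy two-token document and query (not part of the commit):

# Hypothetical illustration of P SEP Q SEP CLS segment ids for XLNet.
doc, query = ["a", "doc"], ["what", "?"]   # toy tokens, assumed for the example
max_seq_length = 10

segment_ids  = [0] * len(doc) + [0]        # paragraph + SEP -> sequence_a_segment_id
segment_ids += [1] * len(query) + [1]      # query + SEP     -> sequence_b_segment_id
segment_ids += [2]                         # trailing CLS    -> cls_token_segment_id
segment_ids += [3] * (max_seq_length - len(segment_ids))  # padding -> pad_token_segment_id

print(segment_ids)  # [0, 0, 0, 1, 1, 1, 2, 3, 3, 3]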
@@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
+                                 mask_padding_with_zero=True,
+                                 sequence_a_is_doc=False):
     """Loads a data file into a list of `InputBatch`s."""
     unique_id = 1000000000
@@ -272,11 +273,13 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 p_mask.append(0)
                 cls_index = 0

-            # Query
-            for token in query_tokens:
-                tokens.append(token)
-                segment_ids.append(sequence_a_segment_id)
-                p_mask.append(1)
+            # XLNet: P SEP Q SEP CLS
+            # Others: CLS Q SEP P SEP
+            if not sequence_a_is_doc:
+                # Query
+                tokens += query_tokens
+                segment_ids += [sequence_a_segment_id] * len(query_tokens)
+                p_mask += [1] * len(query_tokens)

-            # SEP token
-            tokens.append(sep_token)
+                # SEP token
+                tokens.append(sep_token)
@@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                                        split_token_index)
                 token_is_max_context[len(tokens)] = is_max_context
                 tokens.append(all_doc_tokens[split_token_index])
-                segment_ids.append(sequence_b_segment_id)
+                if not sequence_a_is_doc:
+                    segment_ids.append(sequence_b_segment_id)
+                else:
+                    segment_ids.append(sequence_a_segment_id)
                 p_mask.append(0)
             paragraph_len = doc_span.length

+            if sequence_a_is_doc:
+                # SEP token
+                tokens.append(sep_token)
+                segment_ids.append(sequence_a_segment_id)
+                p_mask.append(1)
+
+                tokens += query_tokens
+                segment_ids += [sequence_b_segment_id] * len(query_tokens)
+                p_mask += [1] * len(query_tokens)
+
             # SEP token
             tokens.append(sep_token)
             segment_ids.append(sequence_b_segment_id)
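Putting the two branches together, the input layout for each doc span depends on sequence_a_is_doc (and on cls_token_at_end, which is handled outside this hunk). A minimal, self-contained sketch of the two orderings named in the comment above, using a hypothetical helper and toy tokens (illustrative only, not the commit's code):

def sketch_layout(query_tokens, doc_tokens, sequence_a_is_doc, cls_token_at_end,
                  cls_token="[CLS]", sep_token="[SEP]"):
    if sequence_a_is_doc:
        # XLNet order: P SEP Q SEP, with CLS appended at the end below
        tokens = doc_tokens + [sep_token] + query_tokens + [sep_token]
    else:
        # Other models: Q SEP P SEP, with CLS prepended at the front below
        tokens = query_tokens + [sep_token] + doc_tokens + [sep_token]
    return tokens + [cls_token] if cls_token_at_end else [cls_token] + tokens

print(sketch_layout(["what", "?"], ["a", "doc"], True, True))
# ['a', 'doc', '[SEP]', 'what', '?', '[SEP]', '[CLS]']   -> XLNet: P SEP Q SEP CLS
print(sketch_layout(["what", "?"], ["a", "doc"], False, False))
# ['[CLS]', 'what', '?', '[SEP]', 'a', 'doc', '[SEP]']   -> Others: CLS Q SEP P SEP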
@@ -341,6 +357,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     start_position = 0
                     end_position = 0
                     span_is_impossible = True
                 else:
-                    doc_offset = len(query_tokens) + 2
+                    if sequence_a_is_doc:
+                        doc_offset = 0
+                    else:
+                        doc_offset = len(query_tokens) + 2
                     start_position = tok_start_position - doc_start + doc_offset
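The last hunk keeps the answer span aligned with the new layout: when the document is sequence A it starts at position 0, otherwise it sits after CLS, the query tokens, and a SEP (hence the +2). A small sketch of that arithmetic with a hypothetical helper and toy numbers (assumption-labeled, not the library code):

def map_answer_span(tok_start_position, tok_end_position, doc_start,
                    query_len, sequence_a_is_doc):
    # tok_*_position index into all_doc_tokens; doc_start is the first
    # document token covered by the current doc span.
    if sequence_a_is_doc:
        # XLNet layout (P SEP Q SEP CLS): document tokens start at position 0.
        doc_offset = 0
    else:
        # Other layout (CLS Q SEP P SEP): CLS + query + SEP precede the document.
        doc_offset = query_len + 2
    start_position = tok_start_position - doc_start + doc_offset
    end_position = tok_end_position - doc_start + doc_offset
    return start_position, end_position

# Example: answer at document tokens 10..12, doc span starting at token 5,
# query of 8 tokens.
print(map_answer_span(10, 12, 5, 8, sequence_a_is_doc=False))  # (15, 17)
print(map_answer_span(10, 12, 5, 8, sequence_a_is_doc=True))   # (5, 7)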