Commit 7334bf6c authored by thomwolf's avatar thomwolf
Browse files

pad on left for xlnet

parent c888663f
...@@ -198,14 +198,17 @@ def main(): ...@@ -198,14 +198,17 @@ def main():
list(filter(None, args.xlnet_model.split('/'))).pop(), list(filter(None, args.xlnet_model.split('/'))).pop(),
str(args.max_seq_length), str(args.max_seq_length),
str(task_name))) str(task_name)))
try: if os.path.exists(cached_train_features_file):
logger.info("Loading train features for cache file %s", cached_train_features_file)
with open(cached_train_features_file, "rb") as reader: with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader) train_features = pickle.load(reader)
except: else:
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
train_features = convert_examples_to_features( train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer, output_mode, train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN, cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2) sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0: if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file) logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer: with open(cached_train_features_file, "wb") as writer:
...@@ -344,14 +347,17 @@ def main(): ...@@ -344,14 +347,17 @@ def main():
list(filter(None, args.xlnet_model.split('/'))).pop(), list(filter(None, args.xlnet_model.split('/'))).pop(),
str(args.max_seq_length), str(args.max_seq_length),
str(task_name))) str(task_name)))
try: if os.path.exists(cached_eval_features_file):
logger.info("Loading eval features for cache file %s", cached_eval_features_file)
with open(cached_eval_features_file, "rb") as reader: with open(cached_eval_features_file, "rb") as reader:
eval_features = pickle.load(reader) eval_features = pickle.load(reader)
except: else:
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode, eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN, cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2) sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0: if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving eval features into cached file %s", cached_eval_features_file) logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
with open(cached_eval_features_file, "wb") as writer: with open(cached_eval_features_file, "wb") as writer:
......
...@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor): ...@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
def convert_examples_to_features(examples, label_list, max_seq_length, def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer, output_mode, tokenizer, output_mode,
cls_token_at_end=False, cls_token='[CLS]', cls_token_at_end=False, pad_on_left=False,
sep_token='[SEP]', cls_token_segment_id=0): cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
sequence_a_segment_id=0, sequence_b_segment_id=1,
cls_token_segment_id=1, pad_token_segment_id=0,
mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s """ Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token: `cls_token_at_end` define the location of the CLS token:
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP] - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
...@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length, ...@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# used as as the "sentence vector". Note that this only makes sense because # used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned. # the entire model is fine-tuned.
tokens = tokens_a + [sep_token] tokens = tokens_a + [sep_token]
segment_ids = [0] * len(tokens) segment_ids = [sequence_a_segment_id] * len(tokens)
if tokens_b: if tokens_b:
tokens += tokens_b + [sep_token] tokens += tokens_b + [sep_token]
segment_ids += [1] * (len(tokens_b) + 1) segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
if cls_token_at_end: if cls_token_at_end:
tokens = tokens + [cls_token] tokens = tokens + [cls_token]
...@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length, ...@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# The mask has 1 for real tokens and 0 for padding tokens. Only real # The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to. # tokens are attended to.
input_mask = [1] * len(input_ids) input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length. # Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids)) padding_length = max_seq_length - len(input_ids)
input_ids += padding if pad_on_left:
input_mask += padding input_ids = ([pad_token] * padding_length) + input_ids
segment_ids += padding input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
else:
input_ids = input_ids + ([pad_token] * padding_length)
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_seq_length assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length assert len(input_mask) == max_seq_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment