Commit 7334bf6c authored by thomwolf's avatar thomwolf
Browse files

pad on left for xlnet

parent c888663f
......@@ -198,14 +198,17 @@ def main():
list(filter(None, args.xlnet_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
try:
if os.path.exists(cached_train_features_file):
logger.info("Loading train features for cache file %s", cached_train_features_file)
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
except:
else:
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
......@@ -344,14 +347,17 @@ def main():
list(filter(None, args.xlnet_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
try:
if os.path.exists(cached_eval_features_file):
logger.info("Loading eval features for cache file %s", cached_eval_features_file)
with open(cached_eval_features_file, "rb") as reader:
eval_features = pickle.load(reader)
except:
else:
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
with open(cached_eval_features_file, "wb") as writer:
......
......@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer, output_mode,
cls_token_at_end=False, cls_token='[CLS]',
sep_token='[SEP]', cls_token_segment_id=0):
cls_token_at_end=False, pad_on_left=False,
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
sequence_a_segment_id=0, sequence_b_segment_id=1,
cls_token_segment_id=1, pad_token_segment_id=0,
mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token:
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
......@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = tokens_a + [sep_token]
segment_ids = [0] * len(tokens)
segment_ids = [sequence_a_segment_id] * len(tokens)
if tokens_b:
tokens += tokens_b + [sep_token]
segment_ids += [1] * (len(tokens_b) + 1)
segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
if cls_token_at_end:
tokens = tokens + [cls_token]
......@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
padding_length = max_seq_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
else:
input_ids = input_ids + ([pad_token] * padding_length)
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment