Commit b915ba9d authored by Rémi Louf

pad sequence with 0, mask with -1

parent dc580dd4
@@ -58,7 +58,7 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """
-    def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
+    def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
@@ -165,7 +165,12 @@ def _fit_to_block_size(sequence, block_size):
     if len(sequence) > block_size:
         return sequence[:block_size]
     else:
-        return sequence.extend([-1] * (block_size - len(sequence)))
+        return sequence.extend([0] * (block_size - len(sequence)))
+
+
+def mask_padding_tokens(sequence):
+    """ Replace the padding token with -1 values """
+    return [s if s != 0 else -1 for s in sequence]


 def load_and_cache_examples(args, tokenizer):
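For reference, the padding/masking pair introduced here can be exercised in isolation. The following is a minimal sketch, not the repository's exact code: the helper names and token ids are illustrative, and it pads by concatenation because list.extend() returns None.

# Sketch of the pad-then-mask pattern from this commit (illustrative only):
# sequences are right-padded to block_size with 0, then padding positions
# are replaced by -1 in the labels so the LM loss can ignore them.

def fit_to_block_size(sequence, block_size, pad_token=0):
    """Truncate or right-pad a token list to exactly block_size items."""
    if len(sequence) > block_size:
        return sequence[:block_size]
    return sequence + [pad_token] * (block_size - len(sequence))


def mask_padding_tokens(sequence, pad_token=0, ignore_index=-1):
    """Replace padding tokens with -1 so they do not contribute to the loss."""
    return [tok if tok != pad_token else ignore_index for tok in sequence]


tokens = [101, 2023, 2003, 102]        # hypothetical token ids
padded = fit_to_block_size(tokens, 8)  # [101, 2023, 2003, 102, 0, 0, 0, 0]
labels = mask_padding_tokens(padded)   # [101, 2023, 2003, 102, -1, -1, -1, -1]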
@@ -219,11 +224,8 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("***** Running training *****")
     logger.info(" Num examples = %d", len(train_dataset))
     logger.info(" Num Epochs = %d", args.num_train_epochs)
-    logger.info(
-        " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
-    )
-    logger.info(
-        " Total train batch size (w. parallel, distributed & accumulation) = %d",
+    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
+    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
         args.train_batch_size
         * args.gradient_accumulation_steps
         * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
@@ -242,7 +244,7 @@ def train(args, train_dataset, model, tokenizer):
             source = ([s for s, _ in batch]).to(args.device)
             target = ([t for _, t in batch]).to(args.device)
             model.train()
-            outputs = model(source, target)
+            outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target))
             loss = outputs[0]
             loss.backward()
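Passing the masked target as decoder_lm_labels presumably pairs with a cross-entropy loss that skips the -1 positions; the snippet below sketches that assumption in plain PyTorch and is not the model's actual loss code.

import torch
import torch.nn as nn

# Assumption: the LM loss is computed with ignore_index=-1, so positions
# masked by mask_padding_tokens() contribute nothing to the loss.
loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)   # (batch, seq_len, vocab)
labels = torch.tensor([[3, 7, -1, -1]])  # last two positions are padding

loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
print(loss.item())  # averaged over the two non-padding positions only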