Commit dfce4096 authored by Rémi Louf

resolve PR comments

parent 4c3ac4a7
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
""" Finetuning seq2seq models for sequence generation.""" """ Finetuning seq2seq models for sequence generation."""
import argparse import argparse
from collections import deque import functools
import logging import logging
import os import os
import pickle
import random import random
import sys import sys
...@@ -29,7 +28,22 @@ import torch ...@@ -29,7 +28,22 @@ import torch
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedSeq2seq,
Model2Model,
)
from utils_summarization import (
CNNDailyMailDataset,
encode_for_summarization,
fit_to_block_size,
build_lm_labels,
build_mask,
compute_token_type_ids,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig(stream=sys.stdout, level=logging.INFO)
...@@ -46,194 +60,41 @@ def set_seed(args): ...@@ -46,194 +60,41 @@ def set_seed(args):
# ------------ # ------------
class TextDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
assert os.path.isdir(data_dir)
# Load the features that have already been computed, if any
cached_features_file = os.path.join(
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, "rb") as source:
self.examples = pickle.load(source)
return
logger.info("Creating features from dataset at %s", data_dir)
datasets = ["cnn", "dailymail"]
self.examples = {"source": [], "target": []}
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story):
continue
with open(path_to_story, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
if len(summary_lines) == 0 or len(story_lines) == 0:
continue
story_token_ids, summary_token_ids = _encode_for_summarization(
story_lines, summary_lines, tokenizer
)
story_seq = _fit_to_block_size(story_token_ids, block_size)
self.examples["source"].append(story_seq)
summary_seq = _fit_to_block_size(summary_token_ids, block_size)
self.examples["summary"].append(summary_seq)
logger.info("Saving features into cache file %s", cached_features_file)
with open(cached_features_file, "wb") as sink:
pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.examples)
def __getitem__(self, items):
return (
torch.tensor(self.examples["source"][items]),
torch.tensor(self.examples["target"][items]),
)
def process_story(raw_story):
""" Extract the story and summary from a story file.
Attributes:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the story is empty or contains no highlights.
"""
nonempty_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# for some unknown reason some lines miss a period, add it
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None.
return story_lines, []
# gather summary lines
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
return story_lines, summary_lines
def _encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in summary_lines
]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
return story_token_ids, summary_token_ids
def _add_missing_period(line):
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
return line + "."
def _fit_to_block_size(sequence, block_size):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size we pad it with -1 ids
which correspond to padding tokens.
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([0] * (block_size - len(sequence)))
return sequence
def mask_padding_tokens(sequence):
""" Padding token, encoded as 0, are represented by the value -1 in the
masks """
padded = sequence.clone()
padded[padded == 0] = -1
return padded
def load_and_cache_examples(args, tokenizer): def load_and_cache_examples(args, tokenizer):
dataset = TextDataset(tokenizer, data_dir=args.data_dir) dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
return dataset return dataset
def compute_token_type_ids(batch, separator_token_id): def collate(data, tokenizer, block_size):
""" Segment embeddings as described in [1] """ List of tuple as an input. """
# Remove the examples with an empty story or summary, then encode and fit to the block size
The values {0,1} were found in the repository [2]. data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
data = [
Attributes: encode_for_summarization(story, summary, tokenizer) for story, summary in data
batch: torch.Tensor, size [batch_size, block_size] ]
Batch of input. data = [
separator_token_id: int (
The value of the token that separates the segments. fit_to_block_size(story, block_size, tokenizer.pad_token_id),
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
)
for story, summary in data
]
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." stories = torch.tensor([story for story, summary in data])
arXiv preprint arXiv:1908.08345 (2019). summaries = torch.tensor([summary for story, summary in data])
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217) encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
""" encoder_mask = build_mask(stories, tokenizer.pad_token_id)
batch_embeddings = [] decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
sentence_num = 0 lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
for sequence in batch:
embeddings = [] return (
for s in sequence: stories,
if s == separator_token_id: summaries,
sentence_num += 1 encoder_token_type_ids,
embeddings.append(sentence_num % 2) encoder_mask,
batch_embeddings.append(embeddings) decoder_mask,
return torch.tensor(batch_embeddings) lm_labels,
)
# ---------- # ----------
...@@ -252,7 +113,7 @@ class BertSumOptimizer(object): ...@@ -252,7 +113,7 @@ class BertSumOptimizer(object):
arXiv preprint arXiv:1908.08345 (2019). arXiv preprint arXiv:1908.08345 (2019).
""" """
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9): def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
self.encoder = model.encoder self.encoder = model.encoder
self.decoder = model.decoder self.decoder = model.decoder
self.lr = lr self.lr = lr
...@@ -306,8 +167,12 @@ def train(args, model, tokenizer): ...@@ -306,8 +167,12 @@ def train(args, model, tokenizer):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_dataset = load_and_cache_examples(args, tokenizer) train_dataset = load_and_cache_examples(args, tokenizer)
train_sampler = RandomSampler(train_dataset) train_sampler = RandomSampler(train_dataset)
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=model_collate_fn,
) )
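# Illustrative sketch (not part of the script): with the collate function bound
# above, every batch yielded by `train_dataloader` is a tuple of six LongTensors,
# each of shape (train_batch_size, 512):
#     source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = next(iter(train_dataloader))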
# Training schedule # Training schedule
...@@ -351,26 +216,23 @@ def train(args, model, tokenizer): ...@@ -351,26 +216,23 @@ def train(args, model, tokenizer):
for _ in train_iterator: for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):
source, target = batch source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
source = source.to(args.device) source = source.to(args.device)
target = target.to(args.device) target = target.to(args.device)
token_type_ids = token_type_ids.to(args.device) encoder_token_type_ids = encoder_token_type_ids.to(args.device)
labels_src = labels_src.to(args.device) encoder_mask = encoder_mask.to(args.device)
labels_tgt = labels_tgt.to(args.device) decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
model.train() model.train()
outputs = model( outputs = model(
source, source,
target, target,
token_type_ids=token_type_ids, encoder_token_type_ids=encoder_token_type_ids,
decoder_encoder_attention_mask=labels_src, encoder_attention_mask=encoder_mask,
decoder_attention_mask=labels_tgt, decoder_attention_mask=decoder_mask,
decoder_lm_labels=labels_tgt, decoder_lm_labels=lm_labels,
decoder_initialize_randomly=True,
) )
loss = outputs[0] loss = outputs[0]
...@@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval() model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"): for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target = batch source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target) source = source.to(args.device)
source.to(args.device) target = target.to(args.device)
target.to(args.device) encoder_token_type_ids = encoder_token_type_ids.to(args.device)
labels_src.to(args.device) encoder_mask = encoder_mask.to(args.device)
labels_tgt.to(args.device) decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
with torch.no_grad(): with torch.no_grad():
outputs = model( outputs = model(
source, source,
target, target,
decoder_encoder_attention_mask=labels_src, encoder_token_type_ids=encoder_token_type_ids,
decoder_attention_mask=labels_tgt, encoder_attention_mask=encoder_mask,
decoder_lm_labels=labels_tgt, decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
) )
lm_loss = outputs[0] lm_loss = outputs[0]
eval_loss += lm_loss.mean().item() eval_loss += lm_loss.mean().item()
...@@ -525,7 +389,7 @@ def main(): ...@@ -525,7 +389,7 @@ def main():
) )
parser.add_argument( parser.add_argument(
"--num_train_epochs", "--num_train_epochs",
default=1, default=10,
type=int, type=int,
help="Total number of training epochs to perform.", help="Total number of training epochs to perform.",
) )
...@@ -558,9 +422,13 @@ def main(): ...@@ -558,9 +422,13 @@ def main():
args.device = torch.device("cuda") args.device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count() args.n_gpu = torch.cuda.device_count()
# Load pretrained model and tokenizer # Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = Model2Model.from_pretrained(args.model_name_or_path) config = BertConfig.from_pretrained(args.model_name_or_path)
decoder_model = BertForMaskedLM(config)
model = Model2Model.from_pretrained(
args.model_name_or_path, decoder_model=decoder_model
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
......
from collections import deque
import os
import torch
from torch.utils.data import Dataset
# ------------
# Data loading
# ------------
class CNNDailyMailDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, tokenizer, prefix="train", data_dir=""):
assert os.path.isdir(data_dir)
self.tokenizer = tokenizer
# We initialize the class by listing all the files that contain
# stories and summaries. Files are not read in memory given
# the size of the corpus.
self.stories_path = []
datasets = ("cnn", "dailymail")
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story):
continue
self.stories_path.append(path_to_story)
def __len__(self):
return len(self.stories_path)
def __getitem__(self, idx):
story_path = self.stories_path[idx]
with open(story_path, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
return story_lines, summary_lines
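# Usage sketch (the path below is hypothetical): indexing the dataset returns raw
# text; tokenization and tensor conversion happen later in the collate function.
#     dataset = CNNDailyMailDataset(tokenizer, data_dir="/path/to/cnn_dailymail")
#     story_lines, summary_lines = dataset[0]  # two lists of sentence strings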
def process_story(raw_story):
""" Extract the story and summary from a story file.
Attributes:
raw_story (str): content of the story file as an utf-8 encoded string.
Returns:
A tuple (story_lines, summary_lines) of lists of sentences. Both lists are
empty when the story is empty; summary_lines is empty when the story
contains no highlights.
"""
nonempty_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# Some lines in the corpus are missing a trailing period; add it
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None.
return story_lines, []
# gather summary lines
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
return story_lines, summary_lines
def _add_missing_period(line):
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
return line + "."
# --------------------------
# Encoding and preprocessing
# --------------------------
def fit_to_block_size(sequence, block_size, pad_token):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size it is padded with the id of
the padding token; if it is longer it is truncated to the block size.
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([pad_token] * (block_size - len(sequence)))
return sequence
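# For example, with block_size=5 and pad_token=0:
#     fit_to_block_size([8, 9, 10], 5, 0)          -> [8, 9, 10, 0, 0]
#     fit_to_block_size([1, 2, 3, 4, 5, 6], 5, 0)  -> [1, 2, 3, 4, 5]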
def build_lm_labels(sequence, pad_token):
""" Padding token, encoded as 0, are represented by the value -1 so they
are not taken into account in the loss computation. """
padded = sequence.clone()
padded[padded == pad_token] = -1
return padded
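# For example, with pad_token=0:
#     build_lm_labels(torch.tensor([5, 6, 7, 0, 0]), 0)  -> tensor([ 5,  6,  7, -1, -1])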
def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = sequence.clone()
mask[mask != pad_token] = 1
mask[mask == pad_token] = 0
return mask
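# For example, with pad_token=0:
#     build_mask(torch.tensor([5, 6, 7, 0, 0]), 0)  -> tensor([1, 1, 1, 0, 0])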
def encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in summary_lines
]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
return story_token_ids, summary_token_ids
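# With a BERT-like tokenizer (an assumption; any tokenizer exposing `encode` and
# `add_special_tokens_single_sequence` works) concatenating the per-sentence
# encodings yields the sentence-interleaved layout used by BertSum [1]:
#     [CLS] sent_1 [SEP] [CLS] sent_2 [SEP] ... [CLS] sent_n [SEP]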
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
The values {0,1} were found in the repository [2].
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings = []
for sequence in batch:
sentence_num = 0
embeddings = []
for s in sequence:
if s == separator_token_id:
sentence_num += 1
embeddings.append(sentence_num % 2)
batch_embeddings.append(embeddings)
return torch.tensor(batch_embeddings)
...@@ -14,47 +14,64 @@ ...@@ -14,47 +14,64 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from run_summarization_finetuning import _fit_to_block_size, process_story import numpy as np
import torch
from utils_summarization import (
compute_token_type_ids,
fit_to_block_size,
build_mask,
build_lm_labels,
process_story,
)
class DataLoaderTest(unittest.TestCase):
class SummarizationDataProcessingTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.block_size = 10 self.block_size = 10
def test_truncate_sequence_too_small(self): def test_fit_to_block_sequence_too_small(self):
""" Pad the sequence with 0 if the sequence is smaller than the block size.""" """ Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence = [1, 2, 3, 4] sequence = [1, 2, 3, 4]
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0] expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) self.assertEqual(
fit_to_block_size(sequence, self.block_size, 0), expected_output
)
def test_truncate_sequence_fit_exactly(self): def test_fit_to_block_sequence_fit_exactly(self):
""" Do nothing if the sequence is the right size. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) self.assertEqual(
fit_to_block_size(sequence, self.block_size, 0), expected_output
)
def test_truncate_sequence_too_big(self): def test_fit_to_block_sequence_too_big(self):
""" Truncate the sequence if it is too long. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) self.assertEqual(
fit_to_block_size(sequence, self.block_size, 0), expected_output
)
def test_process_story_no_highlights(self): def test_process_story_no_highlights(self):
""" Processing a story with no highlights should raise an exception. """ Processing a story with no highlights returns an empty list for the summary.
""" """
raw_story = """It was the year of Our Lord one thousand seven hundred and raw_story = """It was the year of Our Lord one thousand seven hundred and
seventy-five.\n\nSpiritual revelations were conceded to England at that seventy-five.\n\nSpiritual revelations were conceded to England at that
favoured period, as at this.""" favoured period, as at this."""
_, summary = process_story(raw_story) _, summary_lines = process_story(raw_story)
self.assertEqual(summary, []) self.assertEqual(summary_lines, [])
def test_process_empty_story(self): def test_process_empty_story(self):
""" An empty story should also raise and exception. """ An empty story returns an empty collection of lines.
""" """
raw_story = "" raw_story = ""
story, summary = process_story(raw_story) story_lines, summary_lines = process_story(raw_story)
self.assertEqual(story, []) self.assertEqual(story_lines, [])
self.assertEqual(summary, []) self.assertEqual(summary_lines, [])
def test_story_with_missing_period(self): def test_process_story_with_missing_period(self):
raw_story = ( raw_story = (
"It was the year of Our Lord one thousand seven hundred and " "It was the year of Our Lord one thousand seven hundred and "
"seventy-five\n\nSpiritual revelations were conceded to England " "seventy-five\n\nSpiritual revelations were conceded to England "
...@@ -71,6 +88,46 @@ class DataLoaderTest(unittest.TestCase): ...@@ -71,6 +88,46 @@ class DataLoaderTest(unittest.TestCase):
expected_summary_lines = ["It was the best of times."] expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines) self.assertEqual(expected_summary_lines, summary_lines)
def test_build_lm_labels_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = sequence
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_lm_labels(self):
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal(
build_mask(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 23).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self):
separator = 101
batch = torch.tensor(
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
)
expected = torch.tensor(
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
)
result = compute_token_type_ids(batch, separator)
np.testing.assert_array_equal(result, expected)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -26,189 +26,220 @@ import torch ...@@ -26,189 +26,220 @@ import torch
from torch import nn from torch import nn
class ModelWithBeamSearch(nn.Module): class TransformerBeamSearch(nn.Module):
def __init__( def __init__(
self, self,
model, model,
tokenizer,
batch_size,
beam_size, beam_size,
start_token_id,
end_token_id,
pad_token_id,
min_length, min_length,
max_length, max_length,
alpha, alpha=0,
block_trigram=True, block_repeating_trigram=True,
): ):
""" """
Attributes: Attributes:
mask_word_id: token id that corresponds to the mask mask_word_id: token id that corresponds to the mask
""" """
super(ModelWithBeamSearch, self).__init__() super(TransformerBeamSearch, self).__init__()
self.model = model self.model = model
self.tokenizer = tokenizer
self.batch_size = batch_size
self.start_token_id = tokenizer.start_token_id
self.end_token_id = tokenizer.end_token_id
self.pad_token_id = tokenizer.pad_token_id
self.beam_size = beam_size self.beam_size = beam_size
self.start_token_id = start_token_id
self.end_token_id = end_token_id
self.pad_token_id = pad_token_id
self.min_length = min_length self.min_length = min_length
self.max_length = max_length self.max_length = max_length
self.block_repeating_trigram = block_repeating_trigram
self.apply_length_penalty = False if alpha == 0 else True
self.alpha = alpha self.alpha = alpha
self.block_trigram = block_trigram
def forward(self, input_ids, **kwargs): # State of the beam
# Separate the encoder- and decoder- specific kwargs. A kwarg is self.hypotheses = [[] for _ in range(batch_size)]
# decoder-specific it the key starts with `decoder_` self.batch_offset = torch.arange(batch_size, dtype=torch.long)
self.beam_offset = torch.arange(
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
)
self.growing_beam = torch.full(
(batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
)
self.topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
).repeat(batch_size)
self.results = {
"prediction": [[] for _ in batch_size],
"scores": [[] for _ in batch_size],
}
self._step = 0
self.is_done = False
def step(self, log_probabilities):
""" Grows the beam by one step. """
self._step += 1
# The batch size changes as some beams finish so we define _B
vocab_size = log_probabilities.size(-1)
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += self.topk_log_probabilities.view(-1, 1)
self.enforce_min_length(log_probabilities)
if self.block_repeating_trigram:
self.remove_repeating_trigrams(log_probabilities, _B)
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.view(
_B, self.beam_size * vocab_size
).topk(self.beam_size, dim=1)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
topk_scores = topk_log_probabilities / self.length_penalty()
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
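# Example of the index arithmetic above (illustration only): with beam_size=2
# and vocab_size=5 each batch row holds 2 * 5 = 10 flattened scores; a flat
# index of 7 decomposes into beam 7 // 5 = 1 and token id 7 % 5 = 2.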
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
-1
)
# Append the last predictions
self.growing_beam = torch.cat(
[
self.growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.end_token_id)
self.enforce_max_length(is_finished)
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = self.growing_beam.view(
-1, self.beam_size, self.growing_beam.size(1)
)
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = self.batch_offset[i]
for j in finished_hyp:
self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
self.hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
self.results["scores"][b].append(best_score)
self.results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
self.is_done = True
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(
0, non_finished
)
self.batch_offset = self.batch_offset.index_select(0, non_finished)
self.growing_beam = predictions.index_select(0, non_finished).view(
-1, self.growing_beam.size(-1)
)
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
return surviving_beams_rows
def forward(self, encoder_input_ids, **kwargs):
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = { kwargs_encoder = {
argument: value argument[len("encoder_"):]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if not argument.startswith("decoder_") if argument.startswith("encoder_")
} }
kwargs_decoder = { kwargs_decoder = {
argument[len("decoder_"):]: value argument[len("decoder_"):]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if argument.startswith("decoder_") if argument.startswith("decoder_")
} }
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
batch_size, _ = input_ids.size(0) # forward pass on the encoder
encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
# Variables that keep track of the status of the search
hypotheses = [[] for _ in range(batch_size)]
batch_offset = torch.arange(batch_size, dtype=torch.long)
beam_offset = torch.arange(
0,
batch_size * self.beam_size,
step=self.beam_size,
dtype=torch.long,
)
growing_beam = torch.full(
(batch_size * self.beam_size, 1),
self.start_token_id,
dtype=torch.long,
)
topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1),
dtype=torch.float,
).repeat(batch_size)
# Forward pass on the encoder
encoder_outputs = self.encoder(input_ids, kwargs_encoder)
kwargs_decoder["encoder_hidden_states"] = tile( kwargs_decoder["encoder_hidden_states"] = tile(
encoder_outputs, self.beam_size, dim=0 encoder_outputs, self.beam_size, dim=0
) )
results = {} # grow the beam by generating sequences in an autoregressive way
results["predictions"] = [[] for _ in batch_size] self.growing_beam = torch.full(
results["scores"] = [[] for _ in batch_size] (self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
)
for step in range(self.max_length): for step in range(self.max_length):
decoder_input = growing_beam[:, -1] decoder_input = self.growing_beam[:, -1]
outputs = self.decoder(decoder_input, kwargs_decoder) outputs = self.model.decoder(decoder_input, kwargs_decoder)
log_probabilities = torch.nn.functional.log_softmax(outputs[1]) log_probabilities = torch.nn.functional.log_softmax(outputs[1])
vocab_size = log_probabilities.size(-1) surviving_beams_rows = self.step(log_probabilities)
if self.is_done:
# The batch size changes as some beams finish so we define: break
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += topk_log_probabilities.view(-1, 1)
# if the beam has not attained the minimum required length we
# make the end token arbitrarily unlikely.
if step < self.min_length:
log_probabilities[self.end_token_id] = -1e20
# Remove repeating tri-grams
if(self.args.block_trigram):
if(step + 1 > 3):
for i in range(_B * self.beam_size):
tokens = [t for t in growing_beam[i]]
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.topk(
log_probabilities.view(_B, self.beam_size * vocab_size),
self.beam_size,
dim=1
)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
topk_scores = topk_log_probabilities / length_penalty
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (
topk_beam_ids + beam_offset[:_B].view(-1, 1)
).view(-1)
# Append the last predictions
growing_beam = torch.cat(
[
growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.end_token_id)
if step + 1 == self.max_length:
is_finished.fill_(1)
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = batch_offset[i]
for j in finished_hyp:
hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
results["scores"][b].append(best_score)
results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
growing_beam = predictions.index_select(0, non_finished).view(
-1, growing_beam.size(-1)
)
# Re-order the state for the next pass
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states" "encoder_hidden_states"
].index_select(0, surviving_beams_rows) ].index_select(0, surviving_beams_rows)
return results return self.results
def remove_repeating_trigrams(self, log_probabilities, _B):
if(self._step + 1 > 3):
for i in range(_B * self.beam_size):
tokens = [t for t in self.growing_beam[i]]
trigrams = [(tokens[j - 1], tokens[j], tokens[j + 1]) for j in range(1, len(tokens) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
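# Illustration of the trigram blocking above: if a growing beam reads
# [1, 2, 3, 7, 1, 2, 3], its trigrams are (1,2,3), (2,3,7), (3,7,1), (7,1,2),
# (1,2,3); the last trigram already occurred, so the beam is made very unlikely.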
def enforce_min_length(self, log_probabilities):
if self._step < self.min_length:
log_probabilities[:, self.end_token_id] = -1e20
def enforce_max_length(self, is_finished):
if self._step + 1 == self.max_length:
is_finished.fill_(1)
def length_penalty(self):
return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
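# For instance, with alpha=0.6 the penalty is ((5 + 1) / 6) ** 0.6 = 1.0 at the
# first step and ((5 + 4) / 6) ** 0.6 ~= 1.28 at the fourth; dividing the summed
# log-probabilities by this growing factor keeps longer beams competitive.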
def tile(x, count, dim=0): def tile(x, count, dim=0):
......
...@@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel): ...@@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
""" """
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones_like(input_ids)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids) token_type_ids = torch.zeros_like(input_ids)
...@@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel): ...@@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# If a 2D encoder attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask is not None: if encoder_attention_mask.dim() == 3:
encoder_attention_mask = encoder_attention_mask[:, None, None, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility if encoder_attention_mask.dim() == 2:
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel): ...@@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask) encoder_attention_mask=encoder_extended_attention_mask)
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
...@@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel): ...@@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss. Masked language modeling loss.
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Next token prediction loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
...@@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel): ...@@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
if lm_labels is not None: if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one # we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :] prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:] lm_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1)) next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (seq2seq_loss,) + outputs outputs = (next_token_loss,) + outputs
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions) return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """, @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
......
...@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) ...@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module): class PreTrainedSeq2seq(nn.Module):
r""" r"""
:class:`~transformers.Seq2seq` is a generic model class that will be :class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
instantiated as a Seq2seq model with one of the base model classes of instantiated as a Seq2seq model with one of the base model classes of
the library as encoder and (optionally) as decoder when created with the library as encoder and (optionally) as decoder when created with
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
...@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module): ...@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
*model_args, *model_args,
**kwargs **kwargs
): ):
r""" Instantiates an encoder and a decoder from one or two base classes r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
...@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module): ...@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
""" """
# Separate the encoder- and decoder- specific kwargs. A kwarg is # keyword arguments come in 3 flavors: encoder-specific (prefixed by
# decoder-specific it the key starts with `decoder_` # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = { kwargs_encoder = {
argument: value argument[len("encoder_"):]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if not argument.startswith("decoder_") if argument.startswith("encoder_")
} }
kwargs_decoder = { kwargs_decoder = {
argument[len("decoder_") :]: value argument[len("decoder_"):]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if argument.startswith("decoder_") if argument.startswith("decoder_")
} }
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Load and initialize the encoder and decoder # Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made # The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly. # by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("encoder_model", None) encoder = kwargs_encoder.pop("model", None)
if encoder is None: if encoder is None:
kwargs_encoder["is_decoder"] = False
encoder = AutoModel.from_pretrained( encoder = AutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
) )
encoder.config.is_decoder = False
decoder = kwargs_decoder.pop("model", None) decoder = kwargs_decoder.pop("model", None)
if decoder is None: if decoder is None:
kwargs_decoder["is_decoder"] = True
decoder = AutoModelWithLMHead.from_pretrained( decoder = AutoModelWithLMHead.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder decoder_pretrained_model_name_or_path, **kwargs_decoder
) )
decoder.config.is_decoder = True
model = cls(encoder, decoder) model = cls(encoder, decoder)
...@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module): ...@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of decoder input sequence tokens in the vocabulary. Indices of decoder input sequence tokens in the vocabulary.
""" """
# Separate the encoder- and decoder- specific kwargs. A kwarg is # keyword arguments come in 3 flavors: encoder-specific (prefixed by
# decoder-specific it the key starts with `decoder_` # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = { kwargs_encoder = {
argument: value argument[len("encoder_"):]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if not argument.startswith("decoder_") if argument.startswith("encoder_")
} }
kwargs_decoder = { kwargs_decoder = {
argument[len("decoder_") :]: value argument[len("decoder_"):]: value
for argument, value in kwargs.items() for argument, value in kwargs.items()
if argument.startswith("decoder_") if argument.startswith("decoder_")
} }
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
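# Routing illustration (variable names are hypothetical): calling
#     model(encoder_input_ids, decoder_input_ids,
#           encoder_attention_mask=e_mask, decoder_lm_labels=labels, head_mask=h)
# makes kwargs_encoder == {"attention_mask": e_mask, "head_mask": h} and
# kwargs_decoder == {"lm_labels": labels, "head_mask": h} after the merge above.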
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None) encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None: if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0][ encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
-1
] # output of the encoder *stack*
else: else:
encoder_outputs = () encoder_outputs = ()
# Decode # Decode
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :] kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq): class Model2Model(PreTrainedSeq2seq):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
where both the encoder and the decoder are of the same family. If the
name of or the path to a pretrained model is specified, the encoder and
the decoder will be initialized with the pretrained weights (the
cross-attention weights will be initialized randomly if they are not
present).
It is possible to override this behavior and initialize, say, the decoder
randomly by creating it beforehand as follows:
config = BertConfig.from_pretrained('bert-base-uncased')
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Model2Model, self).__init__(*args, **kwargs) super(Model2Model, self).__init__(*args, **kwargs)
self.tie_weights() self.tie_weights()
...@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq): ...@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
model = super(Model2Model, cls).from_pretrained( model = super(Model2Model, cls).from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path, encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path, decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
**kwargs **kwargs
) )
# Some architectures require for the decoder to be initialized randomly
# before fine-tuning.
if kwargs.get("decoder_initialize_randomly", False):
model.decoder.init_weights()
return model return model
......