Commit 2403a665 authored by Rémi Louf's avatar Rémi Louf Committed by Julien Chaumond
Browse files

give transformers API to BertAbs

parent 4d181999
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
from collections import namedtuple
import logging
import pdb
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from model_bertabs import BertAbsSummarizer
from transformers import BertTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
BertAbsConfig = namedtuple(
"BertAbsConfig",
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
# Instantiate the authors' model with the pre-trained weights
config = BertAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
use_bert_emb=False,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
original.eval()
new_model = BertAbsSummarizer(config, torch.device("cpu"))
new_model.eval()
# -------------------
# Convert the weights
# -------------------
logging.info("convert the model")
new_model.encoder.load_state_dict(original.bert.state_dict())
new_model.decoder.generator.load_state_dict(original.generator.state_dict())
new_model.decoder.embeddings.load_state_dict(original.decoder.embeddings.state_dict())
new_model.decoder.pos_emb.load_state_dict(original.decoder.pos_emb.state_dict())
new_model.decoder.transformer_layers.load_state_dict(original.decoder.transformer_layers.state_dict())
new_model.decoder.layer_norm.load_state_dict(original.decoder.layer_norm.state_dict())
# ----------------------------------
# Make sure the outpus are identical
# ----------------------------------
logging.info("Make sure that the models' outputs are identical")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# prepare the model inputs
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
# failsafe to make sure the weights reset does not affect the
# loaded weights.
assert torch.max(torch.abs(original.generator[0].weight - new_model.decoder.generator[0].weight)) == 0
# forward pass
src = encoder_input_ids
tgt = decoder_input_ids
segs = token_type_ids = None
clss = None
mask_src = encoder_attention_mask = None
mask_tgt = decoder_attention_mask = None
mask_cls = None
# The original model does not apply the geneator layer immediatly but rather in
# the beam search (where it combines softmax + linear layer). Since we already
# apply the softmax in our generation process we only apply the linear layer here.
# We make sure that the outputs of the full stack are identical
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
output_original_model = original.generator(output_original_model)
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
output_converted_model = torch.nn.functional.log_softmax(output_converted_model, dim=-1)
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
if are_identical:
logging.info("all weights are equal up to 1e-3")
else:
raise ValueError("the weights are different. The new model is likely different from the original one.")
# The model has been saved with torch.save(model) and this is bound to the exact
# directory structure. We save the state_dict instead.
logging.info("saving the model's state dictionary")
torch.save(new_model.state_dict(), "bert-ext-abs.pt")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BertAbs configuration """
import json
import logging
import sys
from transformers import PretrainedConfig
logger = logging.getLogger(__name__)
BERTABS_FINETUNED_CONFIG_MAP = {
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json",
}
class BertAbsConfig(PretrainedConfig):
r""" Class to store the configuration of the BertAbs model.
Arguments:
temp_dir: string
Unused in the current situation. Kept for compatibility but will be removed.
finetune_bert: bool
Whether to fine-tune the model or not. Will be kept for reference
in case we want to add the possibility to fine-tune the model.
large: bool
Whether to use bert-large as a base.
share_emb: book
Whether the embeddings are shared between the encoder and decoder.
encoder: string
Not clear what this does. Leave to "bert" for pre-trained weights.
max_pos: int
The maximum sequence length that this model will be used with.
enc_layer: int
The numner of hidden layers in the Transformer encoder.
enc_hidden_size: int
The size of the encoder's layers.
enc_heads: int
The number of attention heads for each attention layer in the encoder.
enc_ff_size: int
The size of the encoder's feed-forward layers.
enc_dropout: int
The dropout probabilitiy for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the encoder.
dec_layer: int
The numner of hidden layers in the decoder.
dec_hidden_size: int
The size of the decoder's layers.
dec_heads: int
The number of attention heads for each attention layer in the decoder.
dec_ff_size: int
The size of the decoder's feed-forward layers.
dec_dropout: int
The dropout probabilitiy for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the decoder.
"""
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
def __init__(
self,
vocab_size_or_config_json_file=30522,
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
**kwargs,
):
super(BertAbsConfig, self).__init__(**kwargs)
if self._input_is_path_to_json(vocab_size_or_config_json_file):
path_to_json = vocab_size_or_config_json_file
with open(path_to_json, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.temp_dir = temp_dir
self.finetune_bert = finetune_bert
self.large = large
self.vocab_size = vocab_size_or_config_json_file
self.max_pos = max_pos
self.encoder = encoder
self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout
self.share_emb = share_emb
self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
def _input_is_path_to_json(self, first_argument):
""" Checks whether the first argument passed to config
is the path to a JSON file that contains the config.
"""
is_python_2 = sys.version_info[0] == 2
if is_python_2:
return isinstance(first_argument, unicode)
else:
return isinstance(first_argument, str)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints
The file currently does not do much as we ended up copying the exact model
structure, but I leave it here in case we ever want to refactor the model.
"""
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from model_bertabs import BertAbsSummarizer
from transformers import BertTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
BertAbsConfig = namedtuple(
"BertAbsConfig",
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
# Instantiate the authors' model with the pre-trained weights
config = BertAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
use_bert_emb=False,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
original.eval()
new_model = BertAbsSummarizer(config, torch.device("cpu"))
new_model.eval()
# -------------------
# Convert the weights
# -------------------
logging.info("convert the model")
new_model.bert.load_state_dict(original.bert.state_dict())
new_model.decoder.load_state_dict(original.decoder.state_dict())
new_model.generator.load_state_dict(original.generator.state_dict())
# ----------------------------------
# Make sure the outpus are identical
# ----------------------------------
logging.info("Make sure that the models' outputs are identical")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# prepare the model inputs
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
# failsafe to make sure the weights reset does not affect the
# loaded weights.
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
# forward pass
src = encoder_input_ids
tgt = decoder_input_ids
segs = token_type_ids = None
clss = None
mask_src = encoder_attention_mask = None
mask_tgt = decoder_attention_mask = None
mask_cls = None
# The original model does not apply the geneator layer immediatly but rather in
# the beam search (where it combines softmax + linear layer). Since we already
# apply the softmax in our generation process we only apply the linear layer here.
# We make sure that the outputs of the full stack are identical
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
output_original_generator = original.generator(output_original_model)
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
output_converted_generator = new_model.generator(output_converted_model)
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
if are_identical:
logging.info("all weights are equal up to 1e-3")
else:
raise ValueError("the weights are different. The new model is likely different from the original one.")
# The model has been saved with torch.save(model) and this is bound to the exact
# directory structure. We save the state_dict instead.
logging.info("saving the model's state dictionary")
torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
# MIT License
# Copyright (c) 2019 Yang Liu
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math
import shutil
import time
import os
import numpy as np
import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from transformers import BertModel, BertConfig, PreTrainedModel
from configuration_bertabs import BertAbsConfig
MAX_SIZE = 5000
BERTABS_FINETUNED_MODEL_MAP = {
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin",
}
class BertAbsPreTrainedModel(PreTrainedModel):
config_class = BertAbsConfig
pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP
load_tf_weights = False
base_model_prefix = "bert"
class BertAbs(BertAbsPreTrainedModel):
def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
super(BertAbs, self).__init__(args)
self.args = args
self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
# If pre-trained weights are passed for Bert, load these.
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
if load_bert_pretrained_extractive:
self.bert.model.load_state_dict(
dict(
[
(n[11:], p)
for n, p in bert_extractive_checkpoint.items()
if n.startswith("bert.model")
]
),
strict=True,
)
if args.encoder == "baseline":
bert_config = BertConfig(
self.bert.model.config.vocab_size,
hidden_size=args.enc_hidden_size,
num_hidden_layers=args.enc_layers,
num_attention_heads=8,
intermediate_size=args.enc_ff_size,
hidden_dropout_prob=args.enc_dropout,
attention_probs_dropout_prob=args.enc_dropout,
)
self.bert.model = BertModel(bert_config)
self.vocab_size = self.bert.model.config.vocab_size
if args.max_pos > 512:
my_pos_embeddings = nn.Embedding(
args.max_pos, self.bert.model.config.hidden_size
)
my_pos_embeddings.weight.data[
:512
] = self.bert.model.embeddings.position_embeddings.weight.data
my_pos_embeddings.weight.data[
512:
] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
None, :
].repeat(
args.max_pos - 512, 1
)
self.bert.model.embeddings.position_embeddings = my_pos_embeddings
tgt_embeddings = nn.Embedding(
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
)
if self.args.share_emb:
tgt_embeddings.weight = copy.deepcopy(
self.bert.model.embeddings.word_embeddings.weight
)
self.decoder = TransformerDecoder(
self.args.dec_layers,
self.args.dec_hidden_size,
heads=self.args.dec_heads,
d_ff=self.args.dec_ff_size,
dropout=self.args.dec_dropout,
embeddings=tgt_embeddings,
vocab_size=self.vocab_size,
)
gen_func = nn.LogSoftmax(dim=-1)
self.generator = nn.Sequential(
nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func
)
self.generator[0].weight = self.decoder.embeddings.weight
load_from_checkpoints = False if checkpoint is None else True
if load_from_checkpoints:
self.load_state_dict(checkpoint)
def init_weights(self):
for module in self.decoder.modules():
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
for p in self.generator.parameters():
if p.dim() > 1:
xavier_uniform_(p)
else:
p.data.zero_()
def maybe_tie_embeddings(self, args):
if args.use_bert_emb:
tgt_embeddings = nn.Embedding(
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
)
tgt_embeddings.weight = copy.deepcopy(
self.bert.model.embeddings.word_embeddings.weight
)
self.decoder.embeddings = tgt_embeddings
def forward(
self,
encoder_input_ids,
decoder_input_ids,
token_type_ids,
encoder_attention_mask,
decoder_attention_mask,
):
encoder_output = self.bert(
input_ids=encoder_input_ids,
token_type_ids=token_type_ids,
attention_mask=encoder_attention_mask,
)
encoder_hidden_states = encoder_output[0]
dec_state = self.decoder.init_decoder_state(
encoder_input_ids, encoder_hidden_states
)
decoder_outputs, _ = self.decoder(
decoder_input_ids[:, :-1], encoder_hidden_states, dec_state
)
return decoder_outputs
class Bert(nn.Module):
""" This class is not really necessary and should probably disappear.
"""
def __init__(self, large, temp_dir, finetune=False):
super(Bert, self).__init__()
if large:
self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir)
else:
self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir)
self.finetune = finetune
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
self.eval()
with torch.no_grad():
encoder_outputs, _ = self.model(
input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
**kwargs
)
return encoder_outputs
class TransformerDecoder(nn.Module):
"""
The Transformer decoder from "Attention is All You Need".
Args:
num_layers (int): number of encoder layers.
d_model (int): size of the model
heads (int): number of heads
d_ff (int): size of the inner FF layer
dropout (float): dropout parameters
embeddings (:obj:`onmt.modules.Embeddings`):
embeddings to use, should have positional encodings
attn_type (str): if using a seperate copy attention
"""
def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
super(TransformerDecoder, self).__init__()
# Basic attributes.
self.decoder_type = "transformer"
self.num_layers = num_layers
self.embeddings = embeddings
self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)
# Build TransformerDecoder.
self.transformer_layers = nn.ModuleList(
[
TransformerDecoderLayer(d_model, heads, d_ff, dropout)
for _ in range(num_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
# forward(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask)
# def forward(self, input_ids, state, attention_mask=None, memory_lengths=None,
# step=None, cache=None, encoder_attention_mask=None, encoder_hidden_states=None, memory_masks=None):
def forward(
self,
input_ids,
encoder_hidden_states=None,
state=None,
attention_mask=None,
memory_lengths=None,
step=None,
cache=None,
encoder_attention_mask=None,
):
"""
See :obj:`onmt.modules.RNNDecoderBase.forward()`
memory_bank = encoder_hidden_states
"""
# Name conversion
tgt = input_ids
memory_bank = encoder_hidden_states
memory_mask = encoder_attention_mask
# src_words = state.src
src_words = state.src
src_batch, src_len = src_words.size()
padding_idx = self.embeddings.padding_idx
# Decoder padding mask
tgt_words = tgt
tgt_batch, tgt_len = tgt_words.size()
tgt_pad_mask = (
tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
)
# Encoder padding mask
if memory_mask is not None:
src_len = memory_mask.size(-1)
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
else:
src_pad_mask = (
src_words.data.eq(padding_idx)
.unsqueeze(1)
.expand(src_batch, tgt_len, src_len)
)
# Pass through the embeddings
emb = self.embeddings(input_ids)
output = self.pos_emb(emb, step)
assert emb.dim() == 3 # len x batch x embedding_dim
if state.cache is None:
saved_inputs = []
for i in range(self.num_layers):
prev_layer_input = None
if state.cache is None:
if state.previous_input is not None:
prev_layer_input = state.previous_layer_inputs[i]
output, all_input = self.transformer_layers[i](
output,
memory_bank,
src_pad_mask,
tgt_pad_mask,
previous_input=prev_layer_input,
layer_cache=state.cache["layer_{}".format(i)]
if state.cache is not None
else None,
step=step,
)
if state.cache is None:
saved_inputs.append(all_input)
if state.cache is None:
saved_inputs = torch.stack(saved_inputs)
output = self.layer_norm(output)
if state.cache is None:
state = state.update_state(tgt, saved_inputs)
# Decoders in transformers return a tuple. Beam search will fail
# if we don't follow this convention.
return output, state # , state
def init_decoder_state(self, src, memory_bank, with_cache=False):
""" Init decoder state """
state = TransformerDecoderState(src)
if with_cache:
state._init_cache(memory_bank, self.num_layers)
return state
class PositionalEncoding(nn.Module):
def __init__(self, dropout, dim, max_len=5000):
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
)
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0)
super(PositionalEncoding, self).__init__()
self.register_buffer("pe", pe)
self.dropout = nn.Dropout(p=dropout)
self.dim = dim
def forward(self, emb, step=None):
emb = emb * math.sqrt(self.dim)
if step:
emb = emb + self.pe[:, step][:, None, :]
else:
emb = emb + self.pe[:, : emb.size(1)]
emb = self.dropout(emb)
return emb
def get_emb(self, emb):
return self.pe[:, : emb.size(1)]
class TransformerDecoderLayer(nn.Module):
"""
Args:
d_model (int): the dimension of keys/values/queries in
MultiHeadedAttention, also the input size of
the first-layer of the PositionwiseFeedForward.
heads (int): the number of heads for MultiHeadedAttention.
d_ff (int): the second-layer of the PositionwiseFeedForward.
dropout (float): dropout probability(0-1.0).
self_attn_type (string): type of self-attention scaled-dot, average
"""
def __init__(self, d_model, heads, d_ff, dropout):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
self.drop = nn.Dropout(dropout)
mask = self._get_attn_subsequent_mask(MAX_SIZE)
# Register self.mask as a buffer in TransformerDecoderLayer, so
# it gets TransformerDecoderLayer's cuda behavior automatically.
self.register_buffer("mask", mask)
def forward(
self,
inputs,
memory_bank,
src_pad_mask,
tgt_pad_mask,
previous_input=None,
layer_cache=None,
step=None,
):
"""
Args:
inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`
Returns:
(`FloatTensor`, `FloatTensor`, `FloatTensor`):
* output `[batch_size x 1 x model_dim]`
* attn `[batch_size x 1 x src_len]`
* all_input `[batch_size x current_step x model_dim]`
"""
dec_mask = torch.gt(
tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0
)
input_norm = self.layer_norm_1(inputs)
all_input = input_norm
if previous_input is not None:
all_input = torch.cat((previous_input, input_norm), dim=1)
dec_mask = None
query = self.self_attn(
all_input,
all_input,
input_norm,
mask=dec_mask,
layer_cache=layer_cache,
type="self",
)
query = self.drop(query) + inputs
query_norm = self.layer_norm_2(query)
mid = self.context_attn(
memory_bank,
memory_bank,
query_norm,
mask=src_pad_mask,
layer_cache=layer_cache,
type="context",
)
output = self.feed_forward(self.drop(mid) + query)
return output, all_input
# return output
def _get_attn_subsequent_mask(self, size):
"""
Get an attention mask to avoid using the subsequent info.
Args:
size: int
Returns:
(`LongTensor`):
* subsequent_mask `[1 x size x size]`
"""
attn_shape = (1, size, size)
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
subsequent_mask = torch.from_numpy(subsequent_mask)
return subsequent_mask
class MultiHeadedAttention(nn.Module):
"""
Multi-Head Attention module from
"Attention is All You Need"
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
Similar to standard `dot` attention but uses
multiple attention distributions simulataneously
to select relevant items.
.. mermaid::
graph BT
A[key]
B[value]
C[query]
O[output]
subgraph Attn
D[Attn 1]
E[Attn 2]
F[Attn N]
end
A --> D
C --> D
A --> E
C --> E
A --> F
C --> F
D --> O
E --> O
F --> O
B --> O
Also includes several additional tricks.
Args:
head_count (int): number of parallel heads
model_dim (int): the dimension of keys/values/queries,
must be divisible by head_count
dropout (float): dropout parameter
"""
def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
assert model_dim % head_count == 0
self.dim_per_head = model_dim // head_count
self.model_dim = model_dim
super(MultiHeadedAttention, self).__init__()
self.head_count = head_count
self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
self.softmax = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.use_final_linear = use_final_linear
if self.use_final_linear:
self.final_linear = nn.Linear(model_dim, model_dim)
def forward(
self,
key,
value,
query,
mask=None,
layer_cache=None,
type=None,
predefined_graph_1=None,
):
"""
Compute the context vector and the attention vectors.
Args:
key (`FloatTensor`): set of `key_len`
key vectors `[batch, key_len, dim]`
value (`FloatTensor`): set of `key_len`
value vectors `[batch, key_len, dim]`
query (`FloatTensor`): set of `query_len`
query vectors `[batch, query_len, dim]`
mask: binary mask indicating which keys have
non-zero attention `[batch, query_len, key_len]`
Returns:
(`FloatTensor`, `FloatTensor`) :
* output context vectors `[batch, query_len, dim]`
* one of the attention vectors `[batch, query_len, key_len]`
"""
batch_size = key.size(0)
dim_per_head = self.dim_per_head
head_count = self.head_count
key_len = key.size(1)
query_len = query.size(1)
def shape(x):
""" projection """
return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
def unshape(x):
""" compute context """
return (
x.transpose(1, 2)
.contiguous()
.view(batch_size, -1, head_count * dim_per_head)
)
# 1) Project key, value, and query.
if layer_cache is not None:
if type == "self":
query, key, value = (
self.linear_query(query),
self.linear_keys(query),
self.linear_values(query),
)
key = shape(key)
value = shape(value)
if layer_cache is not None:
device = key.device
if layer_cache["self_keys"] is not None:
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
if layer_cache["self_values"] is not None:
value = torch.cat(
(layer_cache["self_values"].to(device), value), dim=2
)
layer_cache["self_keys"] = key
layer_cache["self_values"] = value
elif type == "context":
query = self.linear_query(query)
if layer_cache is not None:
if layer_cache["memory_keys"] is None:
key, value = self.linear_keys(key), self.linear_values(value)
key = shape(key)
value = shape(value)
else:
key, value = (
layer_cache["memory_keys"],
layer_cache["memory_values"],
)
layer_cache["memory_keys"] = key
layer_cache["memory_values"] = value
else:
key, value = self.linear_keys(key), self.linear_values(value)
key = shape(key)
value = shape(value)
else:
key = self.linear_keys(key)
value = self.linear_values(value)
query = self.linear_query(query)
key = shape(key)
value = shape(value)
query = shape(query)
key_len = key.size(2)
query_len = query.size(2)
# 2) Calculate and scale scores.
query = query / math.sqrt(dim_per_head)
scores = torch.matmul(query, key.transpose(2, 3))
if mask is not None:
mask = mask.unsqueeze(1).expand_as(scores)
scores = scores.masked_fill(mask, -1e18)
# 3) Apply attention dropout and compute context vectors.
attn = self.softmax(scores)
if not predefined_graph_1 is None:
attn_masked = attn[:, -1] * predefined_graph_1
attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)
attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)
drop_attn = self.dropout(attn)
if self.use_final_linear:
context = unshape(torch.matmul(drop_attn, value))
output = self.final_linear(context)
return output
else:
context = torch.matmul(drop_attn, value)
return context
class DecoderState(object):
"""Interface for grouping together the current state of a recurrent
decoder. In the simplest case just represents the hidden state of
the model. But can also be used for implementing various forms of
input_feeding and non-recurrent models.
Modules need to implement this to utilize beam search decoding.
"""
def detach(self):
""" Need to document this """
self.hidden = tuple([_.detach() for _ in self.hidden])
self.input_feed = self.input_feed.detach()
def beam_update(self, idx, positions, beam_size):
""" Need to document this """
for e in self._all:
sizes = e.size()
br = sizes[1]
if len(sizes) == 3:
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[
:, :, idx
]
else:
sent_states = e.view(
sizes[0], beam_size, br // beam_size, sizes[2], sizes[3]
)[:, :, idx]
sent_states.data.copy_(sent_states.data.index_select(1, positions))
def map_batch_fn(self, fn):
raise NotImplementedError()
class TransformerDecoderState(DecoderState):
""" Transformer Decoder state base class """
def __init__(self, src):
"""
Args:
src (FloatTensor): a sequence of source words tensors
with optional feature tensors, of size (len x batch).
"""
self.src = src
self.previous_input = None
self.previous_layer_inputs = None
self.cache = None
@property
def _all(self):
"""
Contains attributes that need to be updated in self.beam_update().
"""
if self.previous_input is not None and self.previous_layer_inputs is not None:
return (self.previous_input, self.previous_layer_inputs, self.src)
else:
return (self.src,)
def detach(self):
if self.previous_input is not None:
self.previous_input = self.previous_input.detach()
if self.previous_layer_inputs is not None:
self.previous_layer_inputs = self.previous_layer_inputs.detach()
self.src = self.src.detach()
def update_state(self, new_input, previous_layer_inputs):
state = TransformerDecoderState(self.src)
state.previous_input = new_input
state.previous_layer_inputs = previous_layer_inputs
return state
def _init_cache(self, memory_bank, num_layers):
self.cache = {}
for l in range(num_layers):
layer_cache = {"memory_keys": None, "memory_values": None}
layer_cache["self_keys"] = None
layer_cache["self_values"] = None
self.cache["layer_{}".format(l)] = layer_cache
def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
self.src = self.src.data.repeat(1, beam_size, 1)
def map_batch_fn(self, fn):
def _recursive_map(struct, batch_dim=0):
for k, v in struct.items():
if v is not None:
if isinstance(v, dict):
_recursive_map(v)
else:
struct[k] = fn(v, batch_dim)
self.src = fn(self.src, 0)
if self.cache is not None:
_recursive_map(self.cache)
def gelu(x):
return (
0.5
* x
* (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
)
class PositionwiseFeedForward(nn.Module):
""" A two-layer Feed-Forward-Network with residual layer norm.
Args:
d_model (int): the size of input for the first-layer of the FFN.
d_ff (int): the hidden layer size of the second-layer
of the FNN.
dropout (float): dropout probability in :math:`[0, 1)`.
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.actv = gelu
self.dropout_1 = nn.Dropout(dropout)
self.dropout_2 = nn.Dropout(dropout)
def forward(self, x):
inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
output = self.dropout_2(self.w_2(inter))
return output + x
#
# TRANSLATOR
# The following code is used to generate summaries using the
# pre-trained weights and beam search.
#
def build_predictor(args, tokenizer, symbols, model, logger=None):
# we should be able to refactor the global scorer a lot
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
translator = Translator(
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger
)
return translator
class GNMTGlobalScorer(object):
"""
NMT re-ranking score from
"Google's Neural Machine Translation System" :cite:`wu2016google`
Args:
alpha (float): length parameter
beta (float): coverage parameter
"""
def __init__(self, alpha, length_penalty):
self.alpha = alpha
penalty_builder = PenaltyBuilder(length_penalty)
self.length_penalty = penalty_builder.length_penalty()
def score(self, beam, logprobs):
"""
Rescores a prediction based on penalty functions
"""
normalized_probs = self.length_penalty(beam, logprobs, self.alpha)
return normalized_probs
class PenaltyBuilder(object):
"""
Returns the Length and Coverage Penalty function for Beam Search.
Args:
length_pen (str): option name of length pen
cov_pen (str): option name of cov pen
"""
def __init__(self, length_pen):
self.length_pen = length_pen
def length_penalty(self):
if self.length_pen == "wu":
return self.length_wu
elif self.length_pen == "avg":
return self.length_average
else:
return self.length_none
"""
Below are all the different penalty terms implemented so far
"""
def length_wu(self, beam, logprobs, alpha=0.0):
"""
NMT length re-ranking score from
"Google's Neural Machine Translation System" :cite:`wu2016google`.
"""
modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha)
return logprobs / modifier
def length_average(self, beam, logprobs, alpha=0.0):
"""
Returns the average probability of tokens in a sequence.
"""
return logprobs / len(beam.next_ys)
def length_none(self, beam, logprobs, alpha=0.0, beta=0.0):
"""
Returns unmodified scores.
"""
return logprobs
class Translator(object):
"""
Uses a model to translate a batch of sentences.
Args:
model (:obj:`onmt.modules.NMTModel`):
NMT model to use for translation
fields (dict of Fields): data fields
beam_size (int): size of beam to use
n_best (int): number of translations produced
max_length (int): maximum length output to produce
global_scores (:obj:`GlobalScorer`):
object to rescore final translations
copy_attn (bool): use copy attention during translation
cuda (bool): use cuda
beam_trace (bool): trace beam search for debugging
logger(logging.Logger): logger.
"""
def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
self.logger = logger
self.cuda = args.visible_gpus != "-1"
self.args = args
self.model = model
self.generator = self.model.generator
self.vocab = vocab
self.symbols = symbols
self.start_token = symbols["BOS"]
self.end_token = symbols["EOS"]
self.global_scorer = global_scorer
self.beam_size = args.beam_size
self.min_length = args.min_length
self.max_length = args.max_length
def translate(self, batch, step, attn_debug=False):
""" Generates summaries from one batch of data.
"""
self.model.eval()
with torch.no_grad():
batch_data = self.translate_batch(batch)
translations = self.from_batch(batch_data)
return translations
def translate_batch(self, batch, fast=False):
"""
Translate a batch of sentences.
Mostly a wrapper around :obj:`Beam`.
Args:
batch (:obj:`Batch`): a batch from a dataset object
data (:obj:`Dataset`): the dataset object
fast (bool): enables fast beam search (may not support all features)
Todo:
Shouldn't need the original dataset.
"""
with torch.no_grad():
return self._fast_translate_batch(
batch, self.max_length, min_length=self.min_length
)
# Where the beam search lives
# I have no idea why it is being called from the method above
def _fast_translate_batch(self, batch, max_length, min_length=0):
""" Beam Search using the encoder inputs contained in `batch`.
"""
# The batch object is funny
# Instead of just looking at the size of the arguments we encapsulate
# a size argument.
# Where is it defined?
beam_size = self.beam_size
batch_size = batch.batch_size
src = batch.src
segs = batch.segs
mask_src = batch.mask_src
src_features = self.model.bert(src, segs, mask_src)
dec_states = self.model.decoder.init_decoder_state(
src, src_features, with_cache=True
)
device = src_features.device
# Tile states and memory beam_size times.
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
src_features = tile(src_features, beam_size, dim=0)
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
beam_offset = torch.arange(
0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device
)
alive_seq = torch.full(
[batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device
)
# Give full probability to the first beam on the first step.
topk_log_probs = torch.tensor(
[0.0] + [float("-inf")] * (beam_size - 1), device=device
).repeat(batch_size)
# Structure that holds finished hypotheses.
hypotheses = [[] for _ in range(batch_size)] # noqa: F812
results = {}
results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812
results["scores"] = [[] for _ in range(batch_size)] # noqa: F812
results["gold_score"] = [0] * batch_size
results["batch"] = batch
for step in range(max_length):
decoder_input = alive_seq[:, -1].view(1, -1)
# Decoder forward.
decoder_input = decoder_input.transpose(0, 1)
dec_out, dec_states = self.model.decoder(
decoder_input, src_features, dec_states, step=step
)
# Generator forward.
log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
vocab_size = log_probs.size(-1)
if step < min_length:
log_probs[:, self.end_token] = -1e20
# Multiply probs by the beam probability.
log_probs += topk_log_probs.view(-1).unsqueeze(1)
alpha = self.global_scorer.alpha
length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha
# Flatten probs into a list of possibilities.
curr_scores = log_probs / length_penalty
if self.args.block_trigram:
cur_len = alive_seq.size(1)
if cur_len > 3:
for i in range(alive_seq.size(0)):
fail = False
words = [int(w) for w in alive_seq[i]]
words = [self.vocab.ids_to_tokens[w] for w in words]
words = " ".join(words).replace(" ##", "").split()
if len(words) <= 3:
continue
trigrams = [
(words[i - 1], words[i], words[i + 1])
for i in range(1, len(words) - 1)
]
trigram = tuple(trigrams[-1])
if trigram in trigrams[:-1]:
fail = True
if fail:
curr_scores[i] = -10e20
curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
# Recover log probs.
topk_log_probs = topk_scores * length_penalty
# Resolve beam origin and true word ids.
topk_beam_index = topk_ids.div(vocab_size)
topk_ids = topk_ids.fmod(vocab_size)
# Map beam_index to batch_index in the flat representation.
batch_index = topk_beam_index + beam_offset[
: topk_beam_index.size(0)
].unsqueeze(1)
select_indices = batch_index.view(-1)
# Append last prediction.
alive_seq = torch.cat(
[alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
)
is_finished = topk_ids.eq(self.end_token)
if step + 1 == max_length:
is_finished.fill_(1)
# End condition is top beam is finished.
end_condition = is_finished[:, 0].eq(1)
# Save finished hypotheses.
if is_finished.any():
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
for i in range(is_finished.size(0)):
b = batch_offset[i]
if end_condition[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
for j in finished_hyp:
hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))
# If the batch reached the end, save the n_best hypotheses.
if end_condition[i]:
best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
score, pred = best_hyp[0]
results["scores"][b].append(score)
results["predictions"][b].append(pred)
non_finished = end_condition.eq(0).nonzero().view(-1)
# If all sentences are translated, no need to go further.
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probs = topk_log_probs.index_select(0, non_finished)
batch_index = batch_index.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
alive_seq = predictions.index_select(0, non_finished).view(
-1, alive_seq.size(-1)
)
# Reorder states.
select_indices = batch_index.view(-1)
src_features = src_features.index_select(0, select_indices)
dec_states.map_batch_fn(
lambda state, dim: state.index_select(dim, select_indices)
)
return results
def from_batch(self, translation_batch):
batch = translation_batch["batch"]
assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
batch_size = batch.batch_size
preds, _, _, tgt_str, src = (
translation_batch["predictions"],
translation_batch["scores"],
translation_batch["gold_score"],
batch.tgt_str,
batch.src,
)
translations = []
for b in range(batch_size):
pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
pred_sents = " ".join(pred_sents).replace(" ##", "")
gold_sent = " ".join(tgt_str[b].split())
raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
raw_src = " ".join(raw_src)
translation = (pred_sents, gold_sent, raw_src)
translations.append(translation)
return translations
def _report_rouge(self, gold_path, can_path):
self.logger.info("Calculating Rouge")
results_dict = test_rouge(self.args.temp_dir, can_path, gold_path)
return results_dict
def tile(x, count, dim=0):
"""
Tiles x on dimension dim count times.
"""
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = (
x.view(batch, -1)
.transpose(0, 1)
.repeat(count, 1)
.transpose(0, 1)
.contiguous()
.view(*out_size)
)
if dim != 0:
x = x.permute(perm).contiguous()
return x
#
# All things ROUGE. Uses `pyrouge` which is a hot mess.
#
def test_rouge(temp_dir, cand, ref):
candidates = [line.strip() for line in open(cand, encoding="utf-8")]
references = [line.strip() for line in open(ref, encoding="utf-8")]
print(len(candidates))
print(len(references))
assert len(candidates) == len(references)
cnt = len(candidates)
current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time))
if not os.path.isdir(tmp_dir):
os.mkdir(tmp_dir)
os.mkdir(tmp_dir + "/candidate")
os.mkdir(tmp_dir + "/reference")
try:
for i in range(cnt):
if len(references[i]) < 1:
continue
with open(
tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8"
) as f:
f.write(candidates[i])
with open(
tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8"
) as f:
f.write(references[i])
r = pyrouge.Rouge155(temp_dir=temp_dir)
r.model_dir = tmp_dir + "/reference/"
r.system_dir = tmp_dir + "/candidate/"
r.model_filename_pattern = "ref.#ID#.txt"
r.system_filename_pattern = r"cand.(\d+).txt"
rouge_results = r.convert_and_evaluate()
print(rouge_results)
results_dict = r.output_to_dict(rouge_results)
finally:
pass
if os.path.isdir(tmp_dir):
shutil.rmtree(tmp_dir)
return results_dict
def rouge_results_to_str(results_dict):
return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format(
results_dict["rouge_1_f_score"] * 100,
results_dict["rouge_2_f_score"] * 100,
results_dict["rouge_l_f_score"] * 100,
results_dict["rouge_1_recall"] * 100,
results_dict["rouge_2_recall"] * 100,
results_dict["rouge_l_recall"] * 100,
)
class BertSumOptimizer(object):
""" Specific optimizer for BertSum.
As described in [1], the authors fine-tune BertSum for abstractive
summarization using two Adam Optimizers with different warm-up steps and
learning rate. They also use a custom learning rate scheduler.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
"""
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
self.encoder = model.encoder
self.decoder = model.decoder
self.lr = lr
self.warmup_steps = warmup_steps
self.optimizers = {
"encoder": torch.optim.Adam(
model.encoder.parameters(),
lr=lr["encoder"],
betas=(beta_1, beta_2),
eps=eps,
),
"decoder": torch.optim.Adam(
model.decoder.parameters(),
lr=lr["decoder"],
betas=(beta_1, beta_2),
eps=eps,
),
}
self._step = 0
self.current_learning_rates = {}
def _update_rate(self, stack):
return self.lr[stack] * min(
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
)
def zero_grad(self):
self.optimizer_decoder.zero_grad()
self.optimizer_encoder.zero_grad()
def step(self):
self._step += 1
for stack, optimizer in self.optimizers.items():
new_rate = self._update_rate(stack)
for param_group in optimizer.param_groups:
param_group["lr"] = new_rate
optimizer.step()
self.current_learning_rates[stack] = new_rate
import argparse
from collections import namedtuple
import logging
import os
import sys
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import BertTokenizer
from modeling_bertabs import BertAbs, build_predictor
from utils_summarization import (
SummarizationDataset,
encode_for_summarization,
build_mask,
fit_to_block_size,
compute_token_type_ids,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
Batch = namedtuple(
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
)
def evaluate(args):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = bertabs = BertAbs.from_pretrained(
"bertabs-finetuned-{}".format(args.finetuned_model)
)
bertabs.to(args.device)
bertabs.eval()
symbols = {
"BOS": tokenizer.vocab["[unused0]"],
"EOS": tokenizer.vocab["[unused1]"],
"PAD": tokenizer.vocab["[PAD]"],
}
# these (unused) arguments are defined to keep the compatibility
# with the legacy code and will be deleted in a next iteration.
args.result_path = ""
args.temp_dir = ""
data_iterator = build_data_iterator(args, tokenizer)
predictor = build_predictor(args, tokenizer, symbols, model)
logger.info("***** Running evaluation *****")
logger.info(" Number examples = %d", len(data_iterator.dataset))
logger.info(" Batch size = %d", args.batch_size)
logger.info("")
logger.info("***** Beam Search parameters *****")
logger.info(" Beam size = %d", args.beam_size)
logger.info(" Minimum length = %d", args.min_length)
logger.info(" Maximum length = %d", args.max_length)
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
for batch in tqdm(data_iterator):
batch_data = predictor.translate_batch(batch)
translations = predictor.from_batch(batch_data)
summaries = [format_summary(t) for t in translations]
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
def format_summary(translation):
""" Transforms the output of the `from_batch` function
into nicely formatted summaries.
"""
raw_summary, _, _ = translation
summary = (
raw_summary.replace("[unused0]", "")
.replace("[unused3]", "")
.replace("[PAD]", "")
.replace("[unused1]", "")
.replace(r" +", " ")
.replace(" [unused2] ", ". ")
.replace("[unused2]", "")
.strip()
)
return summary
def save_summaries(summaries, path, original_document_name):
""" Write the summaries in fies that are prefixed by the original
files' name with the `_summary` appended.
Attributes:
original_document_names: List[string]
Name of the document that was summarized.
path: string
Path were the summaries will be written
summaries: List[string]
The summaries that we produced.
"""
for summary, document_name in zip(summaries, original_document_name):
# Prepare the summary file's name
if "." in document_name:
bare_document_name = ".".join(document_name.split(".")[:-1])
extension = document_name.split(".")[-1]
name = bare_document_name + "_summary." + extension
else:
name = document_name + "_summary"
file_path = os.path.join(path, name)
with open(file_path, "w") as output:
output.write(summary)
#
# LOAD the dataset
#
def build_data_iterator(args, tokenizer):
dataset = load_and_cache_examples(args, tokenizer)
sampler = SequentialSampler(dataset)
collate_fn = lambda data: collate(data, tokenizer, block_size=512)
iterator = DataLoader(
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
)
return iterator
def load_and_cache_examples(args, tokenizer):
dataset = SummarizationDataset(args.documents_dir)
return dataset
def collate(data, tokenizer, block_size):
""" Collate formats the data passed to the data loader.
In particular we tokenize the data batch after batch to avoid keeping them
all in memory. We output the data as a namedtuple to fit the original BertAbs's
API.
"""
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
names = [name for name, _, _ in data]
encoded_text = [
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
]
stories = torch.tensor(
[
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
for story, _ in encoded_text
]
)
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
batch = Batch(
document_names=names,
batch_size=len(stories),
src=stories,
segs=encoder_token_type_ids,
mask_src=encoder_mask,
tgt_str=[""] * len(stories),
)
return batch
def decode_summary(summary_tokens, tokenizer):
""" Decode the summary and return it in a format
suitable for evaluation.
"""
summary_tokens = summary_tokens.to("cpu").numpy()
summary = tokenizer.decode(summary_tokens)
sentences = summary.split(".")
sentences = [s + "." for s in sentences]
return sentences
def main():
""" The main function defines the interface with the users.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--documents_dir",
default=None,
type=str,
required=True,
help="The folder where the documents to summarize are located.",
)
parser.add_argument(
"--summaries_output_dir",
default=None,
type=str,
required=True,
help="The folder in wich the summaries should be written.",
)
# EVALUATION options
parser.add_argument(
"--visible_gpus",
default=-1,
type=int,
help="Number of GPUs with which to do the training.",
)
parser.add_argument(
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
)
# BEAM SEARCH arguments
parser.add_argument(
"--min_length",
default=50,
type=int,
help="Minimum number of tokens for the summaries.",
)
parser.add_argument(
"--max_length",
default=200,
type=int,
help="Maixmum number of tokens for the summaries.",
)
parser.add_argument(
"--beam_size",
default=5,
type=int,
help="The number of beams to start with for each example.",
)
parser.add_argument(
"--alpha",
default=0.95,
type=float,
help="The value of alpha for the length penalty in the beam search.",
)
parser.add_argument(
"--block_trigram",
default=True,
type=bool,
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
)
args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
)
maybe_create_output_dir(args.summaries_output_dir)
evaluate(args)
def documents_dir_is_valid(path):
if not os.path.exists(path):
return False
file_list = os.listdir(path)
if len(file_list) == 0:
return False
return True
def maybe_create_output_dir(path):
if not os.path.exists(path):
os.makedirs(path)
if __name__ == "__main__":
main()
...@@ -10,9 +10,14 @@ from torch.utils.data import Dataset ...@@ -10,9 +10,14 @@ from torch.utils.data import Dataset
# ------------ # ------------
class CNNDailyMailDataset(Dataset): class SummarizationDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models. """ Abstracts the dataset used to train seq2seq models.
The class will process the documents that are located in the specified
folder. The preprocessing will work on any document that is reasonably
formatted. On the CNN/DailyMail dataset it will extract both the story
and the summary.
CNN/Daily News: CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are The CNN/Daily News raw datasets are downloaded from [1]. The stories are
...@@ -25,32 +30,31 @@ class CNNDailyMailDataset(Dataset): ...@@ -25,32 +30,31 @@ class CNNDailyMailDataset(Dataset):
[2] https://github.com/abisee/cnn-dailymail/ [2] https://github.com/abisee/cnn-dailymail/
""" """
def __init__(self, data_dir="", prefix="train"): def __init__(self, path="", prefix="train"):
assert os.path.isdir(data_dir) """ We initialize the class by listing all the documents to summarize.
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
# We initialize the class by listing all the files that contain """
# stories and summaries. Files are not read in memory given assert os.path.isdir(path)
# the size of the corpus.
self.stories_path = [] self.documents = []
datasets = ("cnn", "dailymail") story_filenames_list = os.listdir(path)
for dataset in datasets: for story_filename in story_filenames_list:
path_to_stories = os.path.join(data_dir, dataset, "stories") path_to_story = os.path.join(path, story_filename)
story_filenames_list = os.listdir(path_to_stories) if not os.path.isfile(path_to_story):
for story_filename in story_filenames_list: continue
path_to_story = os.path.join(path_to_stories, story_filename) self.documents.append(path_to_story)
if not os.path.isfile(path_to_story):
continue
self.stories_path.append(path_to_story)
def __len__(self): def __len__(self):
return len(self.stories_path) """ Returns the number of documents. """
return len(self.documents)
def __getitem__(self, idx): def __getitem__(self, idx):
story_path = self.stories_path[idx] document_path = self.documents[idx]
with open(story_path, encoding="utf-8") as source: document_name = document_path.split("/")[-1]
with open(document_path, encoding="utf-8") as source:
raw_story = source.read() raw_story = source.read()
story_lines, summary_lines = process_story(raw_story) story_lines, summary_lines = process_story(raw_story)
return story_lines, summary_lines return document_name, story_lines, summary_lines
def process_story(raw_story): def process_story(raw_story):
...@@ -80,7 +84,7 @@ def process_story(raw_story): ...@@ -80,7 +84,7 @@ def process_story(raw_story):
story_lines.append(element) story_lines.append(element)
except IndexError: except IndexError:
# if "@highlight" is absent from the file we pop # if "@highlight" is absent from the file we pop
# all elements until there is None. # all elements until there is None, raising an exception.
return story_lines, [] return story_lines, []
# gather summary lines # gather summary lines
...@@ -114,14 +118,6 @@ def fit_to_block_size(sequence, block_size, pad_token_id): ...@@ -114,14 +118,6 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
return sequence return sequence
def build_lm_labels(sequence, pad_token_id):
""" Padding token are replaced by the value -1 so they
are not taken into account in the loss computation. """
padded = sequence.clone()
padded[padded == pad_token_id] = -1
return padded
def build_mask(sequence, pad_token_id): def build_mask(sequence, pad_token_id):
""" Builds the mask. The attention mechanism will only attend to positions """ Builds the mask. The attention mechanism will only attend to positions
with value 1. """ with value 1. """
...@@ -165,7 +161,7 @@ def compute_token_type_ids(batch, separator_token_id): ...@@ -165,7 +161,7 @@ def compute_token_type_ids(batch, separator_token_id):
""" """
batch_embeddings = [] batch_embeddings = []
for sequence in batch: for sequence in batch:
sentence_num = 0 sentence_num = -1
embeddings = [] embeddings = []
for s in sequence: for s in sequence:
if s == separator_token_id: if s == separator_token_id:
......
...@@ -21,7 +21,6 @@ from utils_summarization import ( ...@@ -21,7 +21,6 @@ from utils_summarization import (
compute_token_type_ids, compute_token_type_ids,
fit_to_block_size, fit_to_block_size,
build_mask, build_mask,
build_lm_labels,
process_story, process_story,
) )
...@@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase):
expected_summary_lines = ["It was the best of times."] expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines) self.assertEqual(expected_summary_lines, summary_lines)
def test_build_lm_labels_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = sequence
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_lm_labels(self):
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask_no_padding(self): def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4]) sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1]) expected = torch.tensor([1, 1, 1, 1])
...@@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]] [[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
) )
expected = torch.tensor( expected = torch.tensor(
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]] [[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
) )
result = compute_token_type_ids(batch, separator) result = compute_token_type_ids(batch, separator)
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
BertExtAbsConfig = namedtuple(
"BertExtAbsConfig",
["temp_dir", "large", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertExtAbs for the internal architecture.
"""
# Load checkpoints in memory
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
# Instantiate the authors' model with the pre-trained weights
config = BertExtAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints)
bertextabs.eval()
# Instantiate our version of the model
decoder_config = BertConfig(
hidden_size=config.dec_hidden_size,
num_hidden_layers=config.dec_layers,
num_attention_heads=config.dec_heads,
intermediate_size=config.dec_ff_size,
hidden_dropout_prob=config.dec_dropout,
attention_probs_dropout_prob=config.dec_dropout,
is_decoder=True,
)
decoder_model = BertForMaskedLM(decoder_config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder_model)
model.eval()
# Let us now start the weight copying process
model.encoder.load_state_dict(bertextabs.bert.model.state_dict())
# Decoder
# Embeddings. The positional embeddings are equal to the word embedding plus a modulation
# that is computed at each forward pass. This may be a source of discrepancy.
model.decoder.bert.embeddings.word_embeddings.weight = bertextabs.decoder.embeddings.weight
model.decoder.bert.embeddings.position_embeddings.weight = bertextabs.decoder.embeddings.weight
model.decoder.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(bertextabs.decoder.embeddings.weight) # not defined for BertExtAbs decoder
# In the original code the LayerNorms are applied twice in the layers, at the beginning and between the
# attention layers.
model.decoder.bert.embeddings.LayerNorm.weight = bertextabs.decoder.transformer_layers[0].layer_norm_1.weight
for i in range(config.dec_layers):
# self attention
model.decoder.bert.encoder.layer[i].attention.self.query.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_query.weight
model.decoder.bert.encoder.layer[i].attention.self.key.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_keys.weight
model.decoder.bert.encoder.layer[i].attention.self.value.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_values.weight
model.decoder.bert.encoder.layer[i].attention.output.dense.weight = bertextabs.decoder.transformer_layers[i].self_attn.final_linear.weight
model.decoder.bert.encoder.layer[i].attention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].layer_norm_2.weight
# attention
model.decoder.bert.encoder.layer[i].crossattention.self.query.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_query.weight
model.decoder.bert.encoder.layer[i].crossattention.self.key.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_keys.weight
model.decoder.bert.encoder.layer[i].crossattention.self.value.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_values.weight
model.decoder.bert.encoder.layer[i].crossattention.output.dense.weight = bertextabs.decoder.transformer_layers[i].context_attn.final_linear.weight
model.decoder.bert.encoder.layer[i].crossattention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].feed_forward.layer_norm.weight
# intermediate
model.decoder.bert.encoder.layer[i].intermediate.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_1.weight
# output
model.decoder.bert.encoder.layer[i].output.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_2.weight
try:
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i + 1].layer_norm_1.weight
except IndexError:
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.layer_norm.weight
# LM Head
"""
model.decoder.cls.predictions.transform.dense.weight
model.decoder.cls.predictions.transform.dense.biais
model.decoder.cls.predictions.transform.LayerNorm.weight
model.decoder.cls.predictions.transform.LayerNorm.biais
model.decoder.cls.predictions.decoder.weight
model.decoder.cls.predictions.decoder.biais
model.decoder.cls.predictions.biais.data
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertextabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertextabs_checkpoints(
args.bertextabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
...@@ -25,7 +25,6 @@ Use Beam Search to generate sequences using encoder-decoder models. ...@@ -25,7 +25,6 @@ Use Beam Search to generate sequences using encoder-decoder models.
""" """
import torch import torch
from torch import nn from torch import nn
import logging import logging
...@@ -45,6 +44,7 @@ class BeamSearch(object): ...@@ -45,6 +44,7 @@ class BeamSearch(object):
max_length, max_length,
alpha=0, alpha=0,
block_repeating_trigrams=True, block_repeating_trigrams=True,
device=torch.device("cpu"),
): ):
r""" r"""
Inputs: Inputs:
...@@ -156,18 +156,24 @@ class BeamSearch(object): ...@@ -156,18 +156,24 @@ class BeamSearch(object):
kwargs_decoder["encoder_hidden_states"] = tile( kwargs_decoder["encoder_hidden_states"] = tile(
encoder_hidden_states, self.beam_size, dim=0 encoder_hidden_states, self.beam_size, dim=0
) )
kwargs_decoder["encoder_attention_mask"] = tile( try:
kwargs_encoder["attention_mask"], self.beam_size, dim=0 kwargs_decoder["encoder_attention_mask"] = tile(
kwargs_encoder["attention_mask"], self.beam_size, dim=0
)
except:
pass
kwargs_decoder["state"].src = tile(
kwargs_decoder["state"].src, self.beam_size, dim=0
) )
# grow the beam iteratively # grow the beam iteratively
batch_size, block_size = encoder_input_ids.size() batch_size, block_size = encoder_input_ids.size()
self._init_beam_state(batch_size) self._init_beam_state(batch_size)
for step in range(self.max_length): for step in range(self.max_length):
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id) decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id) kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)
next_token_scores = outputs[0][:, -1, :].squeeze(1) next_token_scores = outputs[0][:, -1, :].squeeze(1)
log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0) log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)
...@@ -178,9 +184,13 @@ class BeamSearch(object): ...@@ -178,9 +184,13 @@ class BeamSearch(object):
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states" "encoder_hidden_states"
].index_select(0, surviving_beams_rows) ].index_select(0, surviving_beams_rows)
kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[ try:
"encoder_attention_mask" kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
].index_select(0, surviving_beams_rows) "encoder_attention_mask"
].index_select(0, surviving_beams_rows)
except:
pass
kwargs_decoder["state"] = state
return self.results return self.results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment