"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "978dec9014667f394ab11f79dfc54a9c9a7290c7"
Commit f7eba090 authored by Rémi Louf, committed by Julien Chaumond

clean for release

parent 2a64107e
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
import logging
from collections import namedtuple

import torch

from models.model_builder import AbsSummarizer  # The authors' implementation
from model_bertabs import BertAbsSummarizer
from transformers import BertTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
BertAbsConfig = namedtuple(
    "BertAbsConfig",
    [
        "temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder",
        "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads",
        "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size",
        "dec_heads", "dec_ff_size", "dec_dropout",
    ],
)
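# The fields above mimic the argument namespace that the authors' AbsSummarizer
# expects at construction time.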
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
    # Instantiate the authors' model with the pre-trained weights
    config = BertAbsConfig(
        temp_dir=".",
        finetune_bert=False,
        large=False,
        share_emb=True,
        use_bert_emb=False,
        encoder="bert",
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
    )
    checkpoints = torch.load(path_to_checkpoints, map_location=lambda storage, loc: storage)
    original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
    original.eval()

    new_model = BertAbsSummarizer(config, torch.device("cpu"))
    new_model.eval()

    # -------------------
    # Convert the weights
    # -------------------
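    # BertAbsSummarizer mirrors the original architecture module-for-module, so
    # each submodule's state dict can be copied over wholesale.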
logging.info("convert the model")
new_model.encoder.load_state_dict(original.bert.state_dict())
new_model.decoder.generator.load_state_dict(original.generator.state_dict())
new_model.decoder.embeddings.load_state_dict(original.decoder.embeddings.state_dict())
new_model.decoder.pos_emb.load_state_dict(original.decoder.pos_emb.state_dict())
new_model.decoder.transformer_layers.load_state_dict(original.decoder.transformer_layers.state_dict())
new_model.decoder.layer_norm.load_state_dict(original.decoder.layer_norm.state_dict())
    # -----------------------------------
    # Make sure the outputs are identical
    # -----------------------------------
    logger.info("Make sure that the models' outputs are identical")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # prepare the model inputs
    encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
    encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
    encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
    decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
    decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
    decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
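    # Both inputs are padded to the model's maximum length (max_pos = 512)
    # before the two forward passes are compared.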
    # failsafe to make sure the weights reset does not affect the
    # loaded weights.
    assert torch.max(torch.abs(original.generator[0].weight - new_model.decoder.generator[0].weight)) == 0
    # forward pass
    src = encoder_input_ids
    tgt = decoder_input_ids
    segs = token_type_ids = None
    clss = None
    mask_src = encoder_attention_mask = None
    mask_tgt = decoder_attention_mask = None
    mask_cls = None
    # The original model does not apply the generator layer immediately but rather in
    # the beam search (where it combines softmax + linear layer). Since we already
    # apply the softmax in our generation process we only apply the linear layer here.
    # We make sure that the outputs of the full stack are identical.
    output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
    output_original_model = original.generator(output_original_model)
    output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
    output_converted_model = torch.nn.functional.log_softmax(output_converted_model, dim=-1)

    maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
    print("Maximum absolute difference between outputs: {:.2f}".format(maximum_absolute_difference))

    are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
    if are_identical:
        logger.info("all outputs are equal up to 1e-3")
    else:
        raise ValueError("the outputs are different. The new model is likely different from the original one.")
    # The model has been saved with torch.save(model) and this is bound to the exact
    # directory structure. We save the state_dict instead.
    logger.info("saving the model's state dictionary")
    torch.save(new_model.state_dict(), dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
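
# A minimal usage sketch (the script and checkpoint filenames below are
# hypothetical; substitute the BertAbs checkpoint released by the original authors):
#
#   python convert_bertabs_original_pytorch_checkpoint.py \
#       --bertabs_checkpoint_path ./bertabs_cnndm_checkpoint.pt \
#       --pytorch_dump_folder_path ./bert-ext-abs.pt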
# MIT License
# Copyright (c) 2019 Yang Liu and the HuggingFace team
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......
# progress bars in model download and training scripts
tqdm
# Accessing files from S3 directly.
boto3
# Used for downloading models over HTTP
requests
# For ROUGE
nltk
py-rouge
#! /usr/bin/python3
import argparse
from collections import namedtuple
import logging
...@@ -97,6 +98,32 @@ def evaluate(args):
    print(str_scores)
def save_summaries(summaries, path, original_document_name):
    """ Write the summaries in files that are prefixed by the original
    files' name with the `_summary` appended.

    Attributes:
        original_document_names: List[string]
            Name of the document that was summarized.
        path: string
            Path where the summaries will be written
        summaries: List[string]
            The summaries that we produced.
    """
    for summary, document_name in zip(summaries, original_document_name):
        # Prepare the summary file's name
        if "." in document_name:
            bare_document_name = ".".join(document_name.split(".")[:-1])
            extension = document_name.split(".")[-1]
            name = bare_document_name + "_summary." + extension
        else:
            name = document_name + "_summary"

        file_path = os.path.join(path, name)
        with open(file_path, "w") as output:
            output.write(summary)
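
# For example, summarizing "article.txt" writes "article_summary.txt" into `path`;
# a name without an extension, e.g. "article", yields "article_summary".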
def format_summary(translation):
    """ Transforms the output of the `from_batch` function
    into nicely formatted summaries.
...@@ -151,32 +178,6 @@ def save_rouge_scores(str_scores):
    output.write(str_scores)
def save_summaries(summaries, path, original_document_name):
    """ Write the summaries in files that are prefixed by the original
    files' name with the `_summary` appended.

    Attributes:
        original_document_names: List[string]
            Name of the document that was summarized.
        path: string
            Path where the summaries will be written
        summaries: List[string]
            The summaries that we produced.
    """
    for summary, document_name in zip(summaries, original_document_name):
        # Prepare the summary file's name
        if "." in document_name:
            bare_document_name = ".".join(document_name.split(".")[:-1])
            extension = document_name.split(".")[-1]
            name = bare_document_name + "_summary." + extension
        else:
            name = document_name + "_summary"

        file_path = os.path.join(path, name)
        with open(file_path, "w") as output:
            output.write(summary)
#
# LOAD the dataset
#
...@@ -323,7 +324,7 @@ def main():
        raise FileNotFoundError(
            "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
        )

    os.makedirs(args.summaries_output_dir, exist_ok=True)
    evaluate(args)
...@@ -339,10 +340,5 @@ def documents_dir_is_valid(path):
    return True
def maybe_create_output_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -10,6 +10,3 @@ regex
sentencepiece
# For XLM
sacremoses
# For ROUGE
nltk
py-rouge
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
BertExtAbsConfig = namedtuple(
    "BertExtAbsConfig",
    [
        "temp_dir", "large", "finetune_bert", "encoder", "share_emb",
        "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size",
        "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads",
        "dec_ff_size", "dec_dropout",
    ],
)
def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertExtAbs for the internal architecture.
"""
# Load checkpoints in memory
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
# Instantiate the authors' model with the pre-trained weights
config = BertExtAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints)
bertextabs.eval()
    # Instantiate our version of the model
    decoder_config = BertConfig(
        hidden_size=config.dec_hidden_size,
        num_hidden_layers=config.dec_layers,
        num_attention_heads=config.dec_heads,
        intermediate_size=config.dec_ff_size,
        hidden_dropout_prob=config.dec_dropout,
        attention_probs_dropout_prob=config.dec_dropout,
        is_decoder=True,
    )
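    # is_decoder=True makes each BertLayer include a cross-attention ("crossattention")
    # sub-layer, which the weight-copying loop below relies on.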
    decoder_model = BertForMaskedLM(decoder_config)
    model = Model2Model.from_pretrained("bert-base-uncased", decoder_model=decoder_model)
    model.eval()
    # Let us now start the weight copying process
    model.encoder.load_state_dict(bertextabs.bert.model.state_dict())

    # Decoder
    # Embeddings. The positional embeddings are equal to the word embedding plus a modulation
    # that is computed at each forward pass. This may be a source of discrepancy.
    model.decoder.bert.embeddings.word_embeddings.weight = bertextabs.decoder.embeddings.weight
    model.decoder.bert.embeddings.position_embeddings.weight = bertextabs.decoder.embeddings.weight
    model.decoder.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(bertextabs.decoder.embeddings.weight)  # not defined for BertExtAbs decoder
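    # Caveat: BertEmbeddings adds *learned* position embeddings, while the original
    # decoder appears to add its positional modulation at forward time, so reusing the
    # word embedding matrix as position embeddings above is only an approximation.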
    # In the original code the LayerNorms are applied twice in the layers, at the beginning and between the
    # attention layers.
    model.decoder.bert.embeddings.LayerNorm.weight = bertextabs.decoder.transformer_layers[0].layer_norm_1.weight
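    # Since the original decoder normalizes before each sub-layer while BertLayer
    # normalizes after it, every LayerNorm is shifted by one position: layer i's output
    # LayerNorm is taken from layer i+1's first LayerNorm (the try/except below handles
    # the last layer, which falls back to the decoder's final LayerNorm).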
    for i in range(config.dec_layers):
        # self attention
        model.decoder.bert.encoder.layer[i].attention.self.query.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_query.weight
        model.decoder.bert.encoder.layer[i].attention.self.key.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_keys.weight
        model.decoder.bert.encoder.layer[i].attention.self.value.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_values.weight
        model.decoder.bert.encoder.layer[i].attention.output.dense.weight = bertextabs.decoder.transformer_layers[i].self_attn.final_linear.weight
        model.decoder.bert.encoder.layer[i].attention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].layer_norm_2.weight

        # cross attention
        model.decoder.bert.encoder.layer[i].crossattention.self.query.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_query.weight
        model.decoder.bert.encoder.layer[i].crossattention.self.key.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_keys.weight
        model.decoder.bert.encoder.layer[i].crossattention.self.value.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_values.weight
        model.decoder.bert.encoder.layer[i].crossattention.output.dense.weight = bertextabs.decoder.transformer_layers[i].context_attn.final_linear.weight
        model.decoder.bert.encoder.layer[i].crossattention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].feed_forward.layer_norm.weight

        # intermediate
        model.decoder.bert.encoder.layer[i].intermediate.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_1.weight

        # output
        model.decoder.bert.encoder.layer[i].output.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_2.weight
        try:
            model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i + 1].layer_norm_1.weight
        except IndexError:
            model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.layer_norm.weight
    # LM Head
    """
    model.decoder.cls.predictions.transform.dense.weight
    model.decoder.cls.predictions.transform.dense.bias
    model.decoder.cls.predictions.transform.LayerNorm.weight
    model.decoder.cls.predictions.transform.LayerNorm.bias
    model.decoder.cls.predictions.decoder.weight
    model.decoder.cls.predictions.decoder.bias
    model.decoder.cls.predictions.bias.data
    """
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertextabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertextabs_checkpoints(
args.bertextabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
from .beam_search import BeamSearch
...@@ -117,7 +117,8 @@ class PreTrainedEncoderDecoder(nn.Module):
        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("encoder_")
            and not argument.startswith("decoder_")
        }
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
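        # Convention: kwargs prefixed with "encoder_" or "decoder_" are intended for
        # the corresponding sub-model only; unprefixed kwargs are shared by both.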
...@@ -157,27 +158,14 @@ class PreTrainedEncoderDecoder(nn.Module):
        return model
    def save_pretrained(self, save_directory):
        """ Save a Seq2Seq model and its configuration file in a format such
        that it can be loaded using :func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`

        We save the encoder's and decoder's parameters in two separate directories.
        """
        self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
        self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
        """ The forward pass on a seq2seq model depends on what we are performing:
...@@ -205,7 +193,8 @@ class PreTrainedEncoderDecoder(nn.Module):
        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("encoder_")
            and not argument.startswith("decoder_")
        }
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
...@@ -228,7 +217,9 @@ class PreTrainedEncoderDecoder(nn.Module):
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        if encoder_hidden_states is None:
            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
        else:
            encoder_outputs = ()
......