Commit a1994a71 authored by Rémi Louf, committed by Julien Chaumond

simplified model and configuration

parent 3a9a9f78
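For reference, below is a minimal usage sketch of the simplified interface after this change. It is not part of the commit: the local modeling_bertabs import path and the CPU device are assumptions, while the BertAbs.from_pretrained("bertabs-finetuned-cnndm") call is taken from the updated evaluate() hunk at the end of the diff.

# Minimal sketch (not part of the commit). Assumes BertAbs lives in a local
# modeling_bertabs module, as in the summarization example scripts.
import torch
from modeling_bertabs import BertAbs

# The Bert wrapper and BertAbsConfig no longer take temp_dir, finetune_bert,
# large, share_emb or encoder; the encoder is always bert-base-uncased.
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
model.to(torch.device("cpu"))  # assumed device; evaluate() uses args.device
model.eval()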
@@ -33,17 +33,6 @@ class BertAbsConfig(PretrainedConfig):
     r""" Class to store the configuration of the BertAbs model.

     Arguments:
-        temp_dir: string
-            Unused in the current situation. Kept for compatibility but will be removed.
-        finetune_bert: bool
-            Whether to fine-tune the model or not. Will be kept for reference
-            in case we want to add the possibility to fine-tune the model.
-        large: bool
-            Whether to use bert-large as a base.
-        share_emb: bool
-            Whether the embeddings are shared between the encoder and decoder.
-        encoder: string
-            Not clear what this does. Leave to "bert" for pre-trained weights.
         max_pos: int
             The maximum sequence length that this model will be used with.
         enc_layer: int
@@ -77,11 +66,6 @@ class BertAbsConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size_or_config_json_file=30522,
-        temp_dir=".",
-        finetune_bert=False,
-        large=False,
-        share_emb=True,
-        encoder="bert",
         max_pos=512,
         enc_layers=6,
         enc_hidden_size=512,
@@ -104,21 +88,15 @@ class BertAbsConfig(PretrainedConfig):
             for key, value in json_config.items():
                 self.__dict__[key] = value
         elif isinstance(vocab_size_or_config_json_file, int):
-            self.temp_dir = temp_dir
-            self.finetune_bert = finetune_bert
-            self.large = large
             self.vocab_size = vocab_size_or_config_json_file
             self.max_pos = max_pos
-            self.encoder = encoder
             self.enc_layers = enc_layers
             self.enc_hidden_size = enc_hidden_size
             self.enc_heads = enc_heads
             self.enc_ff_size = enc_ff_size
             self.enc_dropout = enc_dropout
-            self.share_emb = share_emb
             self.dec_layers = dec_layers
             self.dec_hidden_size = dec_hidden_size
             self.dec_heads = dec_heads
...
@@ -53,7 +53,7 @@ class BertAbs(BertAbsPreTrainedModel):
     def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
         super(BertAbs, self).__init__(args)
         self.args = args
-        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
+        self.bert = Bert()

         # If pre-trained weights are passed for Bert, load these.
         load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
@@ -69,18 +69,6 @@ class BertAbs(BertAbsPreTrainedModel):
                 strict=True,
             )

-        if args.encoder == "baseline":
-            bert_config = BertConfig(
-                self.bert.model.config.vocab_size,
-                hidden_size=args.enc_hidden_size,
-                num_hidden_layers=args.enc_layers,
-                num_attention_heads=8,
-                intermediate_size=args.enc_ff_size,
-                hidden_dropout_prob=args.enc_dropout,
-                attention_probs_dropout_prob=args.enc_dropout,
-            )
-            self.bert.model = BertModel(bert_config)
-
         self.vocab_size = self.bert.model.config.vocab_size
         if args.max_pos > 512:
@@ -101,10 +89,10 @@ class BertAbs(BertAbsPreTrainedModel):
         tgt_embeddings = nn.Embedding(
             self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
         )
-        if self.args.share_emb:
-            tgt_embeddings.weight = copy.deepcopy(
-                self.bert.model.embeddings.word_embeddings.weight
-            )
+        tgt_embeddings.weight = copy.deepcopy(
+            self.bert.model.embeddings.word_embeddings.weight
+        )

         self.decoder = TransformerDecoder(
             self.args.dec_layers,
@@ -141,16 +129,6 @@ class BertAbs(BertAbsPreTrainedModel):
             else:
                 p.data.zero_()

-    def maybe_tie_embeddings(self, args):
-        if args.use_bert_emb:
-            tgt_embeddings = nn.Embedding(
-                self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
-            )
-            tgt_embeddings.weight = copy.deepcopy(
-                self.bert.model.embeddings.word_embeddings.weight
-            )
-            self.decoder.embeddings = tgt_embeddings
-
     def forward(
         self,
         encoder_input_ids,
@@ -178,14 +156,9 @@ class Bert(nn.Module):
     """ This class is not really necessary and should probably disappear.
     """

-    def __init__(self, large, temp_dir, finetune=False):
+    def __init__(self):
         super(Bert, self).__init__()
-        if large:
-            self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir)
-        else:
-            self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir)
-        self.finetune = finetune
+        self.model = BertModel.from_pretrained("bert-base-uncased")

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
         self.eval()
...
@@ -31,9 +31,9 @@ Batch = namedtuple(
 def evaluate(args):
     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
-    model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
-    bertabs.to(args.device)
-    bertabs.eval()
+    model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
+    model.to(args.device)
+    model.eval()

     symbols = {
         "BOS": tokenizer.vocab["[unused0]"],
...