Commit 34858ae1 authored by thomwolf

adding bert whole words, bertgerman and gpt-2 medium models, head masking

parent 80684f6f
@@ -492,9 +492,12 @@ where
- `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-german-cased`: Trained on German data only, 12-layer, 768-hidden, 12-heads, 110M parameters [Performance Evaluation](https://deepset.ai/german-bert)
- `bert-large-uncased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the tokens corresponding to a word at once)
- `bert-large-cased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the tokens corresponding to a word at once)
- `openai-gpt`: OpenAI GPT English model, 12-layer, 768-hidden, 12-heads, 110M parameters
- `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters
- `gpt2-medium`: OpenAI GPT-2 English model, 24-layer, 1024-hidden, 16-heads, 345M parameters
- `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters
- a path or url to a pretrained model archive containing:
...
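The new shortcut names above load like any other entry in the maps below. A minimal sketch, assuming the `pytorch_pretrained_bert` package from this repo is installed, the S3 archives are reachable, and the `gpt2-medium` shortcut is registered in the GPT-2 maps as the commit message indicates; the sample text and variable names are illustrative:

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel

# The Whole Word Masking checkpoints share the bert-large architecture;
# only the masked-LM pre-training procedure differs, so they load normally.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking')
model = BertModel.from_pretrained('bert-large-uncased-whole-word-masking')
model.eval()

tokens = tokenizer.tokenize("[CLS] whole word masking masks every sub-token of a word [SEP]")
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
with torch.no_grad():
    # With output_all_encoded_layers=False only the last layer is returned.
    sequence_output, pooled_output = model(input_ids, output_all_encoded_layers=False)

# gpt2-medium reuses the existing GPT-2 classes with a larger configuration.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
```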
@@ -45,6 +45,8 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz",
}
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
@@ -279,13 +281,16 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = output_attentions
        self.keep_multihead_output = keep_multihead_output
        self.multihead_output = None
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -301,7 +306,7 @@ class BertSelfAttention(nn.Module):
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward(self, hidden_states, attention_mask, head_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
@@ -323,7 +328,20 @@ class BertSelfAttention(nn.Module):
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        # attention_probs has shape bsz x n_heads x N x N
        if head_mask is not None:
            # unsqueeze is not in-place, so the broadcastable view must be reassigned
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can define heads to mask for each instance in the batch
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        if self.keep_multihead_output:
            self.multihead_output = context_layer
            self.multihead_output.retain_grad()
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
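The masking above relies on `head_mask` broadcasting against `attention_probs` of shape `(bsz, n_heads, seq_len, seq_len)`: a 1-D mask ends up shaped `(1, n_heads, 1, 1)` and a 2-D mask `(bsz, n_heads, 1, 1)`. A small sketch of the two accepted layouts (tensor sizes and values are illustrative, not part of the commit):

```python
import torch

bsz, n_heads, seq_len = 2, 16, 10
attention_probs = torch.rand(bsz, n_heads, seq_len, seq_len)

# 1-D mask: one multiplier per head, shared across the whole batch.
head_mask = torch.ones(n_heads)
head_mask[3] = 0.0                                               # silence head 3 everywhere
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)   # -> (1, n_heads, 1, 1)
masked_probs = attention_probs * head_mask

# 2-D mask: a different per-head pattern for each instance in the batch.
per_example_mask = torch.ones(bsz, n_heads)
per_example_mask[0, 5] = 0.0                                     # silence head 5 for the first example only
per_example_mask = per_example_mask.unsqueeze(-1).unsqueeze(-1)  # -> (bsz, n_heads, 1, 1)
masked_probs = masked_probs * per_example_mask
```

`keep_multihead_output` stores the per-head context tensor before it is reshaped and calls `retain_grad()` on it, so gradients with respect to each head's output can be read back after a backward pass, the usual ingredient for head-importance analysis.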
@@ -353,8 +371,8 @@ class BertAttention(nn.Module):
        self.self = BertSelfAttention(config, output_attentions=output_attentions)
        self.output = BertSelfOutput(config)
    def forward(self, input_tensor, attention_mask, head_mask=None):
        self_output = self.self(input_tensor, attention_mask, head_mask)
        if self.output_attentions:
            attentions, self_output = self_output
        attention_output = self.output(self_output, input_tensor)
@@ -400,8 +418,8 @@ class BertLayer(nn.Module):
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
    def forward(self, hidden_states, attention_mask, head_mask=None):
        attention_output = self.attention(hidden_states, attention_mask, head_mask)
        if self.output_attentions:
            attentions, attention_output = attention_output
        intermediate_output = self.intermediate(attention_output)
@@ -418,11 +436,11 @@ class BertEncoder(nn.Module):
        layer = BertLayer(config, output_attentions=output_attentions)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
        all_encoder_layers = []
        all_attentions = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, head_mask)
            if self.output_attentions:
                attentions, hidden_states = hidden_states
                all_attentions.append(attentions)
@@ -731,7 +749,7 @@ class BertModel(BertPreTrainedModel):
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
@@ -755,7 +773,8 @@ class BertModel(BertPreTrainedModel):
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      head_mask=head_mask)
        if self.output_attentions:
            all_attentions, encoded_layers = encoded_layers
        sequence_output = encoded_layers[-1]
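Since `BertEncoder` hands the same `head_mask` to every `BertLayer`, a 1-D mask disables the chosen heads in all layers at once. A hedged usage sketch against the updated `BertModel.forward` signature (checkpoint name and input text are just for illustration):

```python
import torch
from pytorch_pretrained_bert import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

tokens = tokenizer.tokenize("[CLS] the quick brown fox jumps over the lazy dog [SEP]")
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

# Zero out attention heads 0 and 1 in every layer, keep the others untouched.
head_mask = torch.ones(model.config.num_attention_heads)
head_mask[:2] = 0.0

with torch.no_grad():
    sequence_output, pooled_output = model(input_ids,
                                           output_all_encoded_layers=False,
                                           head_mask=head_mask)
```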
@@ -824,9 +843,9 @@ class BertForPreTraining(BertPreTrainedModel):
        self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                            output_all_encoded_layers=False, head_mask=head_mask)
        if self.output_attentions:
            all_attentions, sequence_output, pooled_output = outputs
        else:
@@ -893,9 +912,10 @@ class BertForMaskedLM(BertPreTrainedModel):
        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                            output_all_encoded_layers=False,
                            head_mask=head_mask)
        if self.output_attentions:
            all_attentions, sequence_output, _ = outputs
        else:
@@ -961,9 +981,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
        self.cls = BertOnlyNSPHead(config)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                            output_all_encoded_layers=False,
                            head_mask=head_mask)
        if self.output_attentions:
            all_attentions, _, pooled_output = outputs
        else:
@@ -1033,8 +1054,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
        if self.output_attentions:
            all_attentions, _, pooled_output = outputs
        else:
@@ -1104,11 +1125,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
        if self.output_attentions:
            all_attentions, _, pooled_output = outputs
        else:
@@ -1180,8 +1201,8 @@ class BertForTokenClassification(BertPreTrainedModel):
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
        if self.output_attentions:
            all_attentions, sequence_output, _ = outputs
        else:
@@ -1259,8 +1280,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        self.apply(self.init_bert_weights)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                            output_all_encoded_layers=False,
                            head_mask=head_mask)
        if self.output_attentions:
            all_attentions, sequence_output, _ = outputs
        else:
...
@@ -35,6 +35,8 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'bert-base-uncased': 512,
@@ -45,6 +47,8 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'bert-base-multilingual-cased': 512,
    'bert-base-chinese': 512,
    'bert-base-german-cased': 512,
    'bert-large-uncased-whole-word-masking': 512,
    'bert-large-cased-whole-word-masking': 512,
}
VOCAB_NAME = 'vocab.txt'
...