Commit 5456d823 authored by thomwolf

more versatile model loading

parent 9b2540b5
@@ -606,7 +606,9 @@ class BertPreTrainedModel(nn.Module):
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = 'bert.' if not hasattr(model, 'bert') and any(s.startwith('bert.') for s in state_dict.keys()) else ''
start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
start_prefix = 'bert.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
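A minimal sketch of the prefix decision introduced above, assuming a hypothetical checkpoint saved from a model that wrapped BERT under a `bert.` prefix and is now loaded into a bare `BertModel`; the key names and stand-in class below are illustrative:

# Illustrative keys from a checkpoint whose BERT weights carry a "bert." prefix.
state_dict_keys = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.self.query.weight",
    "classifier.weight",  # task-specific head weight, not part of the bare model
]

class BareModel:  # stand-in for a BertModel instance without a `bert` attribute
    pass

model = BareModel()

# Same decision as in the diff: start loading with the "bert." prefix so the
# bare model's un-prefixed module hierarchy lines up with the prefixed keys.
start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict_keys):
    start_prefix = 'bert.'
print(start_prefix)  # -> 'bert.'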
@@ -120,6 +120,7 @@ class OpenAIGPTConfig(object):
self,
vocab_size_or_config_json_file=40478,
n_special=0,
n_positions=512,
n_ctx=512,
n_embd=768,
n_layer=12,
@@ -135,7 +136,8 @@ class OpenAIGPTConfig(object):
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
n_ctx: Number of positional embeddings.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
@@ -159,6 +161,7 @@ class OpenAIGPTConfig(object):
self.vocab_size = vocab_size_or_config_json_file
self.n_special = n_special
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
@@ -175,7 +178,7 @@ class OpenAIGPTConfig(object):
@property
def total_num_embeddings(self):
return self.vocab_size + self.n_special + self.n_ctx
return self.vocab_size + self.n_special + self.n_positions
@classmethod
def from_dict(cls, json_object):
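A quick check of what the revised property returns with the defaults shown in this diff (vocab_size 40478, n_special 0, n_positions 512); the numbers are just the documented defaults:

vocab_size, n_special, n_positions = 40478, 0, 512  # defaults from the config above
# word rows, then special-token rows, then position rows share one embedding matrix
total_num_embeddings = vocab_size + n_special + n_positions
print(total_num_embeddings)  # 40990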
@@ -234,7 +237,7 @@ class Attention(nn.Module):
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
@@ -247,9 +250,9 @@ class Attention(nn.Module):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
# w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights
# XD: self.b may be larger than w, so we need to crop it
b = self.b[:, :, : w.size(-2), : w.size(-1)]
b = self.bias[:, :, : w.size(-2), : w.size(-1)]
w = w * b + -1e9 * (1 - b)
w = nn.Softmax(dim=-1)(w)
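A self-contained sketch of the renamed mask buffer and the cropping trick above, assuming a single batch element and head; tensor sizes are illustrative:

import torch

n_ctx, seq_len = 512, 4
# lower-triangular causal mask built once at full n_ctx, broadcastable over batch and heads
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

w = torch.randn(1, 1, seq_len, seq_len)    # raw attention scores for a shorter sequence
b = bias[:, :, :w.size(-2), :w.size(-1)]   # crop the mask to the score size
w = w * b + -1e9 * (1 - b)                 # push future positions to ~-1e9 before softmax
w = torch.nn.Softmax(dim=-1)(w)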
@@ -474,10 +477,12 @@ class OpenAIGPTPreTrainedModel(nn.Module):
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if key.endswith(".g"):
new_key = key[:-2] + ".weight"
elif key.endswith(".b"):
new_key = key[:-2] + ".bias"
elif key.endswith(".w"):
new_key = key[:-2] + ".weight"
if new_key:
old_keys.append(key)
new_keys.append(new_key)
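A runnable sketch of the suffix-based renaming above, applied to a few hypothetical OpenAI-style keys; the key names are made up for illustration, and only the `.g`/`.b`/`.w` suffix handling mirrors the diff:

state_dict = {
    "h.0.ln_1.g": 1,         # layer-norm gain  -> ".weight"
    "h.0.ln_1.b": 2,         # layer-norm bias  -> ".bias"
    "h.0.attn.c_attn.w": 3,  # linear kernel    -> ".weight"
}

old_keys, new_keys = [], []
for key in state_dict.keys():
    new_key = None
    if key.endswith(".g"):
        new_key = key[:-2] + ".weight"
    elif key.endswith(".b"):
        new_key = key[:-2] + ".bias"
    elif key.endswith(".w"):
        new_key = key[:-2] + ".weight"
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(sorted(state_dict))  # ['h.0.attn.c_attn.weight', 'h.0.ln_1.bias', 'h.0.ln_1.weight']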
@@ -502,7 +507,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
if child is not None:
load(child, prefix + name + ".")
if hasattr(model, "transformer") and all(not s.startwith('transformer.') for s in state_dict.keys()):
start_model = model
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
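A minimal sketch of the new dispatch above, assuming a hypothetical head model that exposes the GPT body as a `transformer` attribute while the checkpoint was saved from the bare model; key names and classes are illustrative:

state_dict_keys = ["embed.weight", "h.0.attn.c_attn.weight"]  # no "transformer." prefix

class Body:       # stand-in for the bare OpenAIGPTModel
    pass

class HeadModel:  # stand-in for a model wrapping the body as `transformer`
    def __init__(self):
        self.transformer = Body()

model = HeadModel()
start_model = model
if hasattr(model, "transformer") and all(not s.startswith("transformer.") for s in state_dict_keys):
    start_model = model.transformer  # load directly into the body, as in the diff
print(start_model is model.transformer)  # True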
@@ -541,7 +547,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
total_num_embeddings - 1] ______________________
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
You should use the associated indices to index the embeddings.
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
@@ -554,7 +560,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_positions - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
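Given the index layout documented above, a small sketch of building explicit `position_ids`, using the default sizes from this diff; the values are purely illustrative:

import torch

vocab_size, n_special, n_positions = 40478, 0, 512   # documented defaults
seq_len = 6

# position rows come after the word and special rows in the shared embedding matrix
offset = vocab_size + n_special
position_ids = torch.arange(offset, offset + seq_len, dtype=torch.long).unsqueeze(0)
print(position_ids)  # tensor([[40478, 40479, 40480, 40481, 40482, 40483]])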
@@ -578,7 +584,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config)
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True)
@@ -598,7 +604,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.init_weights(self.embed)
# Copy word and positional embeddings from the previous weights
self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
self.embed.weight.data[-self.config.n_ctx :, :] = old_embed.weight.data[-self.config.n_ctx :, :]
self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
def forward(self, input_ids, position_ids=None, token_type_ids=None):
if position_ids is None:
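A tiny sketch of the embedding-copy pattern above when the number of special tokens changes, with made-up sizes; only the slicing mirrors the diff:

import torch.nn as nn

vocab_size, n_positions, n_embd = 10, 4, 8   # illustrative sizes
old_n_special, new_n_special = 0, 2

old_embed = nn.Embedding(vocab_size + old_n_special + n_positions, n_embd)
new_embed = nn.Embedding(vocab_size + new_n_special + n_positions, n_embd)

# word rows sit at the front and position rows at the back of the matrix, so both
# can be carried over; only the special rows in between stay freshly initialized
new_embed.weight.data[:vocab_size, :] = old_embed.weight.data[:vocab_size, :]
new_embed.weight.data[-n_positions:, :] = old_embed.weight.data[-n_positions:, :]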
@@ -645,7 +651,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
total_num_embeddings - 1] ______________________
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
You should use these indices to index the word, special and position embeddings.
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
@@ -658,7 +664,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_positions - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
@@ -725,7 +731,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
total_num_embeddings - 1] ______________________
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
You should use these indices to index the word, special and position embeddings.
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
@@ -741,7 +747,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise.
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special,
config.vocab_size + config.n_special + config.n_ctx - 1[.
config.vocab_size + config.n_special + config.n_positions - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
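For the multiple-choice inputs described above, a shape-only sketch; the mask variable name below is hypothetical, since the argument's name is not visible in this hunk:

import torch

batch_size, num_choices, seq_len = 2, 2, 5

input_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
# hypothetical name: a 0/1 mask marking, per choice, the token whose hidden
# state feeds the multiple-choice head (here the last token of each sequence)
choice_token_mask = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
choice_token_mask[..., -1] = 1
print(input_ids.shape, choice_token_mask.shape)  # torch.Size([2, 2, 5]) torch.Size([2, 2, 5])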