Commit 5456d823 authored by thomwolf

more versatile model loading

parent 9b2540b5
@@ -606,7 +606,9 @@ class BertPreTrainedModel(nn.Module):
             for name, child in module._modules.items():
                 if child is not None:
                     load(child, prefix + name + '.')
-        start_prefix = 'bert.' if not hasattr(model, 'bert') and any(s.startwith('bert.') for s in state_dict.keys()) else ''
+        start_prefix = ''
+        if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
+            start_prefix = 'bert.'
         load(model, prefix=start_prefix)
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
@@ -120,6 +120,7 @@ class OpenAIGPTConfig(object):
         self,
         vocab_size_or_config_json_file=40478,
         n_special=0,
+        n_positions=512,
         n_ctx=512,
         n_embd=768,
         n_layer=12,

@@ -135,7 +136,8 @@ class OpenAIGPTConfig(object):
         Args:
             vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
             n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
-            n_ctx: Number of positional embeddings.
+            n_positions: Number of positional embeddings.
+            n_ctx: Size of the causal mask (usually same as n_positions).
             n_embd: Dimensionality of the embeddings and hidden states.
             n_layer: Number of hidden layers in the Transformer encoder.
             n_head: Number of attention heads for each attention layer in

@@ -159,6 +161,7 @@ class OpenAIGPTConfig(object):
             self.vocab_size = vocab_size_or_config_json_file
             self.n_special = n_special
             self.n_ctx = n_ctx
+            self.n_positions = n_positions
             self.n_embd = n_embd
             self.n_layer = n_layer
             self.n_head = n_head

@@ -175,7 +178,7 @@ class OpenAIGPTConfig(object):
     @property
     def total_num_embeddings(self):
-        return self.vocab_size + self.n_special + self.n_ctx
+        return self.vocab_size + self.n_special + self.n_positions

     @classmethod
     def from_dict(cls, json_object):
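With n_positions split out from n_ctx, the single GPT embedding matrix is laid out as word BPE tokens, then learned special tokens, then positions, and total_num_embeddings counts all three. A small sketch of how the index ranges work out; the default 40478/512 values come from the config above, while n_special=2 is just an example:

vocab_size, n_special, n_positions = 40478, 2, 512

total_num_embeddings = vocab_size + n_special + n_positions  # 40992 here

# Row layout of the shared embedding matrix:
#   word BPE tokens    : [0, vocab_size)                      -> rows 0..40477
#   special tokens     : [vocab_size, vocab_size + n_special) -> rows 40478..40479
#   position embeddings: [vocab_size + n_special, total)      -> rows 40480..40991
first_special_index = vocab_size
first_position_index = vocab_size + n_special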
@@ -234,7 +237,7 @@ class Attention(nn.Module):
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % config.n_head == 0
-        self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale

@@ -247,9 +250,9 @@ class Attention(nn.Module):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
         # XD: self.b may be larger than w, so we need to crop it
-        b = self.b[:, :, : w.size(-2), : w.size(-1)]
+        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
         w = w * b + -1e9 * (1 - b)
         w = nn.Softmax(dim=-1)(w)
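Renaming the buffer from "b" to "bias" only changes the registered parameter name; the masking itself is unchanged. A standalone sketch of what the cropped lower-triangular mask does to the attention scores, with toy shapes:

import torch

n_ctx = 6
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

# Attention scores for a sequence shorter than n_ctx: crop the mask to match.
w = torch.randn(1, 1, 4, 4)                    # [batch, heads, seq, seq]
b = bias[:, :, : w.size(-2), : w.size(-1)]     # crop to 4 x 4
w = w * b + -1e9 * (1 - b)                     # future positions pushed to ~-1e9
probs = torch.nn.Softmax(dim=-1)(w)            # rows sum to 1, future masked out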
@@ -474,10 +477,12 @@ class OpenAIGPTPreTrainedModel(nn.Module):
         new_keys = []
         for key in state_dict.keys():
             new_key = None
-            if "gamma" in key:
-                new_key = key.replace("gamma", "weight")
-            if "beta" in key:
-                new_key = key.replace("beta", "bias")
+            if key.endswith(".g"):
+                new_key = key[:-2] + ".weight"
+            elif key.endswith(".b"):
+                new_key = key[:-2] + ".bias"
+            elif key.endswith(".w"):
+                new_key = key[:-2] + ".weight"
             if new_key:
                 old_keys.append(key)
                 new_keys.append(new_key)
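The old code matched the substrings "gamma"/"beta" anywhere in a key; the new code keys off the ".g", ".b" and ".w" suffixes used by the converted OpenAI GPT checkpoint, which is part of what makes loading those weights more versatile. A minimal sketch of the new mapping on a made-up state_dict; the key names below are illustrative only:

state_dict = {
    "h.0.attn.c_attn.w": None,  # conv1d weight in the original naming
    "h.0.attn.c_attn.b": None,  # conv1d bias
    "h.0.ln_1.g": None,         # layer-norm gain -> weight
    "h.0.ln_1.b": None,         # layer-norm bias
}

renamed = {}
for key, value in state_dict.items():
    if key.endswith(".g"):
        new_key = key[:-2] + ".weight"
    elif key.endswith(".b"):
        new_key = key[:-2] + ".bias"
    elif key.endswith(".w"):
        new_key = key[:-2] + ".weight"
    else:
        new_key = key
    renamed[new_key] = value
# renamed now uses the ".weight"/".bias" names that nn.Module parameters expect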
@@ -502,7 +507,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
                 if child is not None:
                     load(child, prefix + name + ".")
-        if hasattr(model, "transformer") and all(not s.startwith('transformer.') for s in state_dict.keys()):
+        start_model = model
+        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
             start_model = model.transformer
         load(start_model, prefix="")
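Initializing start_model = model before the check is the substance of this hunk: when the checkpoint was saved from the base transformer (no "transformer." prefix on its keys) but the target model wraps one in a .transformer attribute, loading descends into the submodule; otherwise the full model is used. A brief standalone sketch of that dispatch, with an illustrative function name:

def pick_start_module(model, state_dict):
    # Default: load into the full model.
    start_model = model
    # If the model wraps a base transformer but the checkpoint keys carry no
    # "transformer." prefix, descend into the submodule so the keys line up.
    if hasattr(model, "transformer") and all(not k.startswith("transformer.") for k in state_dict.keys()):
        start_model = model.transformer
    return start_model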
@@ -541,7 +547,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         total_num_embeddings - 1] ______________________
     where total_num_embeddings can be obtained as config.total_num_embeddings and is:
-        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
     You should use the associate indices to index the embeddings.
     The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.

@@ -554,7 +560,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
         were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
     `position_ids`: an optional torch.LongTensor with the same shape as input_ids
-        with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
+        with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_positions - 1[.
     `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
         You can use it to add a third embedding (the previous two being the word and position embeddings)
         to each token in the sentence.
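Because positions live at the end of the shared embedding matrix, explicit position_ids must be offset by vocab_size + n_special rather than start at zero, as the updated docstring describes. A hedged sketch of building such a tensor; the config numbers are the defaults shown earlier and n_special=0 is an assumption:

import torch

vocab_size, n_special, n_positions = 40478, 0, 512
batch_size, seq_len = 2, 10

# Positions are stored after the word and special rows, so explicit
# position_ids start at vocab_size + n_special, not at 0.
offset = vocab_size + n_special
position_ids = torch.arange(offset, offset + seq_len, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_len)
# position_ids.max() stays below vocab_size + n_special + n_positions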
@@ -578,7 +584,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     def __init__(self, config):
         super(OpenAIGPTModel, self).__init__(config)
-        total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
+        total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
         self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
         block = Block(config.n_ctx, config, scale=True)

@@ -598,7 +604,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.init_weights(self.embed)
         # Copy word and positional embeddings from the previous weights
         self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
-        self.embed.weight.data[-self.config.n_ctx :, :] = old_embed.weight.data[-self.config.n_ctx :, :]
+        self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]

     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         if position_ids is None:
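When the number of special tokens changes, a new embedding matrix is built and the pre-trained rows are copied over: word rows from the front, position rows from the back, leaving only the special rows in the middle freshly initialized. A standalone sketch of that copy with toy sizes, not the library method itself:

import torch.nn as nn

vocab_size, n_positions, n_embd = 8, 4, 3
old_n_special, new_n_special = 0, 2

old_embed = nn.Embedding(vocab_size + old_n_special + n_positions, n_embd)
new_embed = nn.Embedding(vocab_size + new_n_special + n_positions, n_embd)

# Word embeddings occupy the first vocab_size rows, positions the last
# n_positions rows; only the special rows in between stay newly initialized.
new_embed.weight.data[:vocab_size, :] = old_embed.weight.data[:vocab_size, :]
new_embed.weight.data[-n_positions:, :] = old_embed.weight.data[-n_positions:, :]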
@@ -645,7 +651,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         total_num_embeddings - 1] ______________________
     where total_num_embeddings can be obtained as config.total_num_embeddings and is:
-        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
     You should use these indices to index the word, special and position embeddings.
     The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.

@@ -658,7 +664,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
         were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
     `position_ids`: an optional torch.LongTensor with the same shape as input_ids
-        with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
+        with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_positions - 1[.
     `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
         You can use it to add a third embedding (the previous two being the word and position embeddings)
         to each token in the sentence.

@@ -725,7 +731,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         total_num_embeddings - 1] ______________________
     where total_num_embeddings can be obtained as config.total_num_embeddings and is:
-        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
     You should use these indices to index the word, special and position embeddings.
     The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.

@@ -741,7 +747,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
     `position_ids`: an optional torch.LongTensor with the same shape as input_ids
         with the position indices (selected in the range [config.vocab_size + config.n_special,
-            config.vocab_size + config.n_special + config.n_ctx - 1[.
+            config.vocab_size + config.n_special + config.n_positions - 1[.
     `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
         You can use it to add a third embedding (the previous two being the word and position embeddings)
         to each token in the sentence.