"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "73f2c342f53f2ff02124da23ba029d80c386e7ce"
Commit ab0e8932 authored by thomwolf

conversion script WIP

parent 5581edb4
conversion script:

```diff
@@ -10,7 +10,7 @@ import argparse
 import tensorflow as tf
 import torch
-from .modeling_pytorch import BertConfig, BertModel
+from modeling_pytorch import BertConfig, BertModel
 
 parser = argparse.ArgumentParser()
```
```diff
@@ -35,6 +35,10 @@ parser.add_argument("--pytorch_dump_path",
 args = parser.parse_args()
 
 def convert():
+    # Initialise PyTorch model
+    config = BertConfig.from_json_file(args.bert_config_file)
+    model = BertModel(config)
+
     # Load weights from TF model
     path = args.tf_checkpoint_path
     print("Converting TensorFlow checkpoint from {}".format(path))
```
```diff
@@ -49,24 +53,26 @@ def convert():
         names.append(name)
         arrays.append(array)
 
-    # Initialise PyTorch model and fill weights-in
-    config = BertConfig.from_json_file(args.bert_config_file)
-    model = BertModel(config)
     for name, array in zip(names, arrays):
         name = name[5:]  # skip "bert/"
-        assert name[-2:] == ":0"
-        name = name[:-2]
         name = name.split('/')
         pointer = model
         for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
-                l = re.split(r'(\d+)', m_name)
+            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
+                l = re.split(r'_(\d+)', m_name)
             else:
                 l = [m_name]
-            pointer = getattr(pointer, l[0])
+            if l[0] == 'kernel':
+                pointer = getattr(pointer, 'weight')
+            else:
+                pointer = getattr(pointer, l[0])
             if len(l) >= 2:
                 num = int(l[1])
                 pointer = pointer[num]
+        if m_name[-11:] == '_embeddings':
+            pointer = getattr(pointer, 'weight')
+        # elif m_name == 'kernel':
+        #     pointer = getattr(pointer, 'weight')
         try:
             assert pointer.shape == array.shape
         except AssertionError as e:
```
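The updated regexes are easier to follow with a concrete variable name. Below, `bert/encoder/layer_3/attention/self/query/kernel` is a hypothetical TF name of the kind the loop walks: `layer_3` splits into an attribute (`layer`) plus an index (`3`), while `kernel` is routed to the PyTorch `weight` attribute by the new branch:

```python
import re

# Hypothetical TF variable name of the kind the loop above traverses.
name = "bert/encoder/layer_3/attention/self/query/kernel"
for m_name in name[5:].split('/'):  # drop the "bert/" prefix
    if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
        l = re.split(r'_(\d+)', m_name)  # the capture group keeps the index
    else:
        l = [m_name]
    print(m_name, '->', l)
# encoder -> ['encoder']
# layer_3 -> ['layer', '3', '']   (attribute 'layer', then pointer[3])
# ...
# kernel  -> ['kernel']           (mapped to 'weight' by the new branch)
```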
```diff
@@ -79,4 +85,3 @@ def convert():
 
 if __name__ == "__main__":
     convert()
-    return None
```
modeling_pytorch.py:

```diff
@@ -129,8 +129,8 @@ class BERTLayerNorm(nn.Module):
 class BERTEmbeddings(nn.Module):
     def __init__(self, config):
         super(BERTEmbeddings, self).__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
 
         # Position embeddings are (normally) a contiguous range so we could use a slice
         # Since the position embedding table is a learned variable, we create it
@@ -142,12 +142,12 @@ class BERTEmbeddings(nn.Module):
         # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
         # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
         # perform a slice.
-        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
         # token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
-        self.token_type_embeddings = nn.Embedding(config.token_type_vocab_size, config.embedding_size)
-        self.LayerNorm = BERTLayerNorm()  # Not snake-cased to stick with TF model variable name
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+        self.LayerNorm = BERTLayerNorm(config)  # Not snake-cased to stick with TF model variable name
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, input_ids, token_type_ids=None):
```
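Switching every table to `config.hidden_size` is what makes the standard BERT embedding sum type-check: the three lookups are added elementwise, so their last dimensions must agree. A toy shape check (sizes are made up):

```python
import torch

hidden_size = 8                               # toy value
words = torch.randn(2, 5, hidden_size)        # word_embeddings(input_ids)
positions = torch.randn(1, 5, hidden_size)    # position_embeddings slice
token_types = torch.randn(2, 5, hidden_size)  # token_type_embeddings(token_type_ids)

embeddings = words + positions + token_types  # broadcasts over the batch dim
print(embeddings.shape)                       # torch.Size([2, 5, 8])
```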
```diff
@@ -185,7 +185,7 @@ class BERTSelfAttention(nn.Module):
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, input_tensor, num_attention_heads, is_key_tensor=False):
+    def transpose_for_scores(self, x, is_key_tensor=False):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
         if is_key_tensor:
```
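A shape walk-through of the reshape in `transpose_for_scores`, with toy sizes. This is the standard multi-head split; the `is_key_tensor` branch is cut off by the hunk, but presumably permutes the key so the attention-score matmul needs no extra transpose:

```python
import torch

batch, seq_len = 2, 4
num_attention_heads, attention_head_size = 3, 5
hidden = num_attention_heads * attention_head_size  # 15

x = torch.randn(batch, seq_len, hidden)
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
x = x.view(*new_x_shape)   # (2, 4, 3, 5): split hidden into heads
x = x.permute(0, 2, 1, 3)  # (2, 3, 4, 5): heads act like a batch dim
print(x.shape)
```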
```diff
@@ -270,7 +270,7 @@ class BERTAttention(nn.Module):
 class BERTIntermediate(nn.Module):
     def __init__(self, config):
-        super(BERTOutput, self).__init__()
+        super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         self.intermediate_act_fn = gelu
@@ -305,13 +305,13 @@ class BERTLayer(nn.Module):
         attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        return hidden_states
+        return layer_output
 
 class BERTEncoder(nn.Module):
     def __init__(self, config):
         super(BERTEncoder, self).__init__()
-        layer = BERTLayer(n_ctx, cfg, scale=True)
+        layer = BERTLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask):
```
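One detail the `BERTLayer(config)` fix makes clearer: `copy.deepcopy` in `BERTEncoder` gives every layer its own parameters rather than sharing one set across the stack. A quick check of that behavior:

```python
import copy
import torch.nn as nn

layer = nn.Linear(4, 4)
layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])
print(layers[0].weight is layers[1].weight)  # False: independent parameters
```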
```diff
@@ -383,7 +383,7 @@ class BertModel(nn.Module):
         ValueError: The config is invalid or one of the input tensor shapes
             is invalid.
         """
-        super(BertModel).__init__()
+        super(BertModel, self).__init__()
         self.embeddings = BERTEmbeddings(config)
         self.encoder = BERTEncoder(config)
         self.pooler = BERTPooler(config)
```
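The `super` fixes in this commit (`super(BERTOutput, ...)` inside `BERTIntermediate`, and `super(BertModel)` without `self`) matter because `nn.Module.__init__` must run before any submodule is assigned. A minimal reproduction of the `BertModel` variant of the bug:

```python
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super(Broken).__init__()      # missing self: nn.Module.__init__ never runs
        self.dense = nn.Linear(4, 4)  # AttributeError: cannot assign module
                                      # before Module.__init__() call

class Fixed(nn.Module):
    def __init__(self):
        super(Fixed, self).__init__()
        self.dense = nn.Linear(4, 4)  # fine

Fixed()
try:
    Broken()
except AttributeError as err:
    print(err)
```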