Commit 314bc6bb authored by Chris
Browse files

added transposes to attention.self.[query,key,value]

parent 8de1faea
...@@ -39,6 +39,24 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str): ...@@ -39,6 +39,24 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
N BertForQuestionAnswering N BertForQuestionAnswering
""" """
tensors_to_transopse = (
"dense.weight",
"attention.self.query",
"attention.self.key",
"attention.self.value"
)
var_map = (
('layer.', 'layer_'),
('word_embeddings.weight', 'word_embeddings'),
('position_embeddings.weight', 'position_embeddings'),
('token_type_embeddings.weight', 'token_type_embeddings'),
('.', '/'),
('LayerNorm/weight', 'LayerNorm/gamma'),
('LayerNorm/bias', 'LayerNorm/beta'),
('weight', 'kernel')
)
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir) os.makedirs(ckpt_dir)
...@@ -47,15 +65,8 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str): ...@@ -47,15 +65,8 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
tf_vars = [] tf_vars = []
def to_tf_var_name(name:str): def to_tf_var_name(name:str):
"""todo: compile as regex""" for patt, repl in iter(var_map):
name = name.replace('layer.', 'layer_') name = name.replace(patt, repl)
name = name.replace('word_embeddings.weight', 'word_embeddings')
name = name.replace('position_embeddings.weight', 'position_embeddings')
name = name.replace('token_type_embeddings.weight', 'token_type_embeddings')
name = name.replace('.', '/')
name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
name = name.replace('weight', 'kernel')
return 'bert/{}'.format(name) return 'bert/{}'.format(name)
def assign_tf_var(tensor:np.ndarray, name:str): def assign_tf_var(tensor:np.ndarray, name:str):
...@@ -69,7 +80,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str): ...@@ -69,7 +80,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
for var_name in state_dict: for var_name in state_dict:
tf_name = to_tf_var_name(var_name) tf_name = to_tf_var_name(var_name)
torch_tensor = state_dict[var_name].numpy() torch_tensor = state_dict[var_name].numpy()
if var_name.endswith('dense.weight'): if any([x in var_name for x in tensors_to_transopse]):
torch_tensor = torch_tensor.T torch_tensor = torch_tensor.T
tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name) tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
tf_vars.append(tf_tensor) tf_vars.append(tf_tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment