Commit 960ef4df authored by thomwolf's avatar thomwolf
Browse files

probably ok weights convertion script

parent ab0e8932
...@@ -9,6 +9,7 @@ import re ...@@ -9,6 +9,7 @@ import re
import argparse import argparse
import tensorflow as tf import tensorflow as tf
import torch import torch
import numpy as np
from modeling_pytorch import BertConfig, BertModel from modeling_pytorch import BertConfig, BertModel
...@@ -55,7 +56,11 @@ def convert(): ...@@ -55,7 +56,11 @@ def convert():
for name, array in zip(names, arrays): for name, array in zip(names, arrays):
name = name[5:] # skip "bert/" name = name[5:] # skip "bert/"
print("Loading {}".format(name))
name = name.split('/') name = name.split('/')
if name[0] in ['redictions', 'eq_relationship']:
print("Skipping")
continue
pointer = model pointer = model
for m_name in name: for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name): if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
...@@ -71,8 +76,8 @@ def convert(): ...@@ -71,8 +76,8 @@ def convert():
pointer = pointer[num] pointer = pointer[num]
if m_name[-11:] == '_embeddings': if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, 'weight')
# elif m_name == 'kernel': elif m_name == 'kernel':
# pointer = getattr(pointer, 'weight') array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert pointer.shape == array.shape
except AssertionError as e: except AssertionError as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment