Commit 960ef4df authored by thomwolf's avatar thomwolf
Browse files

probably ok weights convertion script

parent ab0e8932
......@@ -9,6 +9,7 @@ import re
import argparse
import tensorflow as tf
import torch
import numpy as np
from modeling_pytorch import BertConfig, BertModel
......@@ -55,7 +56,11 @@ def convert():
for name, array in zip(names, arrays):
name = name[5:] # skip "bert/"
print("Loading {}".format(name))
name = name.split('/')
if name[0] in ['redictions', 'eq_relationship']:
print("Skipping")
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
......@@ -71,8 +76,8 @@ def convert():
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
# elif m_name == 'kernel':
# pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment