Commit 19666dcb authored by thomwolf's avatar thomwolf
Browse files

Should fix #438

parent 1d8c2323
...@@ -91,6 +91,8 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path): ...@@ -91,6 +91,8 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
pointer = getattr(pointer, 'bias') pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights': elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, 'weight')
elif l[0] == 'squad':
pointer = getattr(pointer, 'classifier')
else: else:
try: try:
pointer = getattr(pointer, l[0]) pointer = getattr(pointer, l[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment