"pytorch_transformers/optimization.py" did not exist on "886cb49792f0d39b24d285726e9434897dd9dc6e"
Commit 7ba83730 authored by lukovnikov

clean up pr

parent fa0c5a2e
@@ -68,11 +68,17 @@ def convert():
         arrays.append(array)
 
     for name, array in zip(names, arrays):
-        name = name[5:]  # skip "bert/"
+        if not name.startswith("bert"):
+            print("Skipping {}".format(name))
+            continue
+        else:
+            name = name.replace("bert/", "")  # skip "bert/"
         print("Loading {}".format(name))
         name = name.split('/')
-        if name[0] in ['redictions', 'eq_relationship']:
-            print("Skipping")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+        # which are not required for using the pretrained model
+        if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
+            print("Skipping {}".format("/".join(name)))
             continue
         pointer = model
         for m_name in name:
...
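For context, the new filtering keeps only variables under the bert/ scope and drops the AdamWeightDecayOptimizer slot variables adam_m and adam_v, which hold optimizer moments rather than model weights. Below is a minimal, self-contained sketch of that filtering; the variable names in the list are hypothetical examples, not taken from a real checkpoint.

# Sketch of the filtering added in this commit, applied to hypothetical
# TensorFlow checkpoint variable names.
tf_variable_names = [
    "bert/encoder/layer_0/attention/self/query/kernel",
    "bert/encoder/layer_0/attention/self/query/kernel/adam_m",  # optimizer first moment
    "bert/encoder/layer_0/attention/self/query/kernel/adam_v",  # optimizer second moment
    "global_step",                                              # not under the bert/ scope
]

kept = []
for name in tf_variable_names:
    if not name.startswith("bert"):
        print("Skipping {}".format(name))
        continue
    name = name.replace("bert/", "")  # drop the "bert/" prefix
    parts = name.split('/')
    # adam_m / adam_v belong to AdamWeightDecayOptimizer and are not needed
    # to use the pretrained weights.
    if parts[-1] in ("adam_m", "adam_v"):
        print("Skipping {}".format("/".join(parts)))
        continue
    kept.append(parts)

# kept now holds only genuine weight paths, split into the components that the
# rest of convert() walks to find the matching PyTorch parameter.
print(kept)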
@@ -26,6 +26,10 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
+
+
 def gelu(x):
     """Implementation of the gelu activation function.
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -241,8 +245,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        act2fn = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
-        self.intermediate_act_fn = act2fn[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, str) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
...
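The second change hoists the activation lookup into a module-level ACT2FN dict so that BERTIntermediate resolves config.hidden_act whether it is given as a string key or as a callable. A minimal standalone sketch of that pattern follows; the gelu/swish definitions and the Intermediate class are simplified stand-ins rather than the file's actual code, and they are defined before ACT2FN so the snippet imports cleanly (in the committed hunk ACT2FN appears above the gelu definition). The sketch also maps "relu" to the functional relu so every entry is directly callable on a tensor.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def gelu(x):
    # Simplified gelu, matching the erf-based formulation used in BERT.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


# Module-level activation table, in the spirit of the ACT2FN introduced here.
ACT2FN = {"gelu": gelu, "relu": F.relu, "swish": swish}


class Intermediate(nn.Module):
    """Toy analogue of BERTIntermediate: hidden_act may be a string key or a callable."""

    def __init__(self, hidden_size, intermediate_size, hidden_act="gelu"):
        super(Intermediate, self).__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = (ACT2FN[hidden_act]
                                    if isinstance(hidden_act, str) else hidden_act)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        return self.intermediate_act_fn(hidden_states)


layer = Intermediate(8, 32, hidden_act="gelu")  # string key resolved through ACT2FN
out = layer(torch.randn(2, 8))                  # output shape: (2, 32)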