Unverified Commit 82601f4c authored by Tianlei Wu, committed by GitHub

Allow gpt2 to be exported to valid ONNX (#4244)

* allow gpt2 to be exported to valid ONNX model

* cast size from int to float explicitly
parent 39994051
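
For context, the kind of export this change is meant to unblock might look like the sketch below. It assumes the Hugging Face transformers API (GPT2LMHeadModel, GPT2Tokenizer); the model name, opset version, and output path are illustrative and not part of this commit.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Illustrative export: trace GPT-2 and emit an ONNX graph.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer.encode("Hello, world", return_tensors="pt")

with torch.no_grad():
    torch.onnx.export(
        model,                        # module to trace
        (input_ids,),                 # example inputs
        "gpt2.onnx",                  # output path (illustrative)
        input_names=["input_ids"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "logits": {0: "batch", 1: "sequence"},
        },
        opset_version=11,
    )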
@@ -26,7 +26,7 @@ def gelu_new(x):
     """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
         Also see https://arxiv.org/abs/1606.08415
     """
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
 
 if torch.__version__ < "1.4.0":
@@ -36,7 +36,7 @@ else:
 
 
 def gelu_fast(x):
-    return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
+    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
 
 
 ACT2FN = {
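
The literal changes above (1 → 1.0, 2 → 2.0) do not change the math in eager mode, since Python promotes the ints to floats anyway; they only make the constants unambiguously floating-point for the exporter. A quick sanity check one might run (illustrative, not part of the commit):

import math
import torch

x = torch.randn(8)

# Old and new formulations of gelu_new produce the same values in eager mode.
old = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
new = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
assert torch.allclose(old, new)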
@@ -142,10 +142,10 @@ class Attention(nn.Module):
     def _attn(self, q, k, v, attention_mask=None, head_mask=None):
         w = torch.matmul(q, k)
         if self.scale:
-            w = w / (v.size(-1) ** 0.5)
+            w = w / (float(v.size(-1)) ** 0.5)
         nd, ns = w.size(-2), w.size(-1)
         mask = self.bias[:, :, ns - nd : ns, :ns]
-        w = torch.where(mask, w, self.masked_bias.to(w.dtype))
+        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
 
         if attention_mask is not None:
             # Apply the attention mask
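
A minimal illustration of why the attention changes matter for export (a standalone sketch with made-up shapes, not code from the commit): the causal-mask buffer here is stored as a non-bool tensor, while torch.where and the ONNX Where operator expect a boolean condition, and casting the head size to float keeps the scaling a floating-point operation in the traced graph.

import torch

# Stand-in for the model's registered causal-mask buffer; uint8 is assumed here for illustration.
mask = torch.tril(torch.ones(4, 4, dtype=torch.uint8))
w = torch.rand(1, 1, 4, 4)    # attention scores, toy shape
v = torch.rand(1, 1, 4, 8)    # values, toy shape
masked_bias = torch.tensor(-1e4)

# Explicit float cast so the scaling stays a floating-point op in the traced graph.
w = w / (float(v.size(-1)) ** 0.5)

# torch.where needs a bool condition; .bool() also yields a valid input for the ONNX Where node.
w = torch.where(mask.bool(), w, masked_bias.to(w.dtype))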