"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "0d40a52eec5351b113cf861d03b3041ababb1727"
Unverified Commit 82601f4c authored by Tianlei Wu's avatar Tianlei Wu Committed by GitHub
Browse files

Allow gpt2 to be exported to valid ONNX (#4244)

* allow gpt2 to be exported to valid ONNX model

* cast size from int to float explictly
parent 39994051
...@@ -26,7 +26,7 @@ def gelu_new(x): ...@@ -26,7 +26,7 @@ def gelu_new(x):
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415 Also see https://arxiv.org/abs/1606.08415
""" """
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if torch.__version__ < "1.4.0": if torch.__version__ < "1.4.0":
...@@ -36,7 +36,7 @@ else: ...@@ -36,7 +36,7 @@ else:
def gelu_fast(x): def gelu_fast(x):
return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
ACT2FN = { ACT2FN = {
......
...@@ -142,10 +142,10 @@ class Attention(nn.Module): ...@@ -142,10 +142,10 @@ class Attention(nn.Module):
def _attn(self, q, k, v, attention_mask=None, head_mask=None): def _attn(self, q, k, v, attention_mask=None, head_mask=None):
w = torch.matmul(q, k) w = torch.matmul(q, k)
if self.scale: if self.scale:
w = w / (v.size(-1) ** 0.5) w = w / (float(v.size(-1)) ** 0.5)
nd, ns = w.size(-2), w.size(-1) nd, ns = w.size(-2), w.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns] mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask, w, self.masked_bias.to(w.dtype)) w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask # Apply the attention mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment