Unverified Commit c7d06b79 authored by jazzcook15's avatar jazzcook15 Committed by GitHub
Browse files

Fix #3954 - GPT2 is not traceable (#3955)



* Update sqrt computation so it can survive a torch.jit.trace

* Update modeling_gpt2.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 9a0a8c1c
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import logging import logging
import math
import os import os
import torch import torch
...@@ -143,7 +142,7 @@ class Attention(nn.Module): ...@@ -143,7 +142,7 @@ class Attention(nn.Module):
def _attn(self, q, k, v, attention_mask=None, head_mask=None): def _attn(self, q, k, v, attention_mask=None, head_mask=None):
w = torch.matmul(q, k) w = torch.matmul(q, k)
if self.scale: if self.scale:
w = w / math.sqrt(v.size(-1)) w = w / (v.size(-1) ** 0.5)
nd, ns = w.size(-2), w.size(-1) nd, ns = w.size(-2), w.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns] mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask, w, self.masked_bias) w = torch.where(mask, w, self.masked_bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment