Unverified Commit c6acd246 authored by Max Ryabinin, committed by GitHub

Speed up GELU computation with torch.jit (#2988)

* Compile gelu_new with torchscript

* Compile _gelu_python with torchscript

* Wrap gelu_new with torch.jit for torch>=1.4
parent d5d7d886
@@ -18,12 +18,6 @@ def _gelu_python(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
-if torch.__version__ < "1.4.0":
-    gelu = _gelu_python
-else:
-    gelu = F.gelu
-
-
 def gelu_new(x):
     """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
         Also see https://arxiv.org/abs/1606.08415
@@ -31,6 +25,12 @@ def gelu_new(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 
+if torch.__version__ < "1.4.0":
+    gelu = _gelu_python
+else:
+    gelu = F.gelu
+
+gelu_new = torch.jit.script(gelu_new)
 ACT2FN = {
     "relu": F.relu,
     "swish": swish,
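
As context for the change above, here is a minimal micro-benchmark sketch comparing the eager gelu_new with the torch.jit.script-compiled version. The tensor shape, iteration count, and the gelu_new_scripted name are illustrative assumptions, not part of this commit.

# Hypothetical benchmark sketch (not from the commit): time eager vs. scripted
# gelu_new. Shape and iteration count are arbitrary choices for illustration.
import math
import time

import torch


def gelu_new(x):
    """ GELU approximation used in the Google BERT repo (identical to OpenAI GPT),
        as defined in the diff above. """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


# On torch >= 1.4, scripting removes Python-interpreter overhead per call and
# can let the JIT fuse the chain of elementwise ops (tanh, pow, mul, add),
# which is the speedup this PR targets.
gelu_new_scripted = torch.jit.script(gelu_new)

x = torch.randn(64, 128, 1024)

for name, fn in [("eager", gelu_new), ("scripted", gelu_new_scripted)]:
    fn(x)  # warm-up call so the JIT can specialize before timing
    start = time.perf_counter()
    for _ in range(100):
        fn(x)
    elapsed = time.perf_counter() - start
    print(f"{name}: {elapsed:.3f}s for 100 calls")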