Unverified Commit c6acd246 authored by Max Ryabinin, committed by GitHub

Speed up GELU computation with torch.jit (#2988)

* Compile gelu_new with torchscript

* Compile _gelu_python with torchscript

* Wrap gelu_new with torch.jit for torch>=1.4
parent d5d7d886
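
A hedged sketch of the technique this commit applies: torch.jit.script compiles the chain of elementwise ops in the tanh-approximation GELU into a TorchScript graph, which can fuse them into fewer kernels. The timing harness and the name gelu_new_scripted below are illustrative assumptions, not part of the diff.

import math
import timeit

import torch


def gelu_new(x):
    # Tanh approximation of GELU used by BERT/GPT; see https://arxiv.org/abs/1606.08415
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


# torch.jit.script (gated to torch >= 1.4 in this commit) compiles the
# function; the scripted version can fuse the elementwise ops.
gelu_new_scripted = torch.jit.script(gelu_new)

x = torch.randn(1024, 1024)
print("eager:   ", timeit.timeit(lambda: gelu_new(x), number=100))
print("scripted:", timeit.timeit(lambda: gelu_new_scripted(x), number=100))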
@@ -18,12 +18,6 @@ def _gelu_python(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
-if torch.__version__ < "1.4.0":
-    gelu = _gelu_python
-else:
-    gelu = F.gelu
-
-
 def gelu_new(x):
     """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
         Also see https://arxiv.org/abs/1606.08415
@@ -31,6 +25,12 @@ def gelu_new(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 
+if torch.__version__ < "1.4.0":
+    gelu = _gelu_python
+else:
+    gelu = F.gelu
+    gelu_new = torch.jit.script(gelu_new)
+
 ACT2FN = {
     "relu": F.relu,
     "swish": swish,
...
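
For context, a quick numerical check (illustrative, not part of the commit) of why the version gate can swap _gelu_python for F.gelu on newer torch: both compute the exact erf-based GELU, while gelu_new is a close tanh approximation.

import math

import torch
import torch.nn.functional as F

x = torch.randn(10000)

# Exact GELU, as computed by _gelu_python
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
# Tanh approximation, as computed by gelu_new
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

print(torch.allclose(exact, F.gelu(x), atol=1e-6))  # True: F.gelu is the exact erf form
print((approx - exact).abs().max())                 # small but nonzero approximation error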