Commit be5821b8 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Switch to torch.nn.functional.gelu when available

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/735

Differential Revision: D16377046

Pulled By: myleott

fbshipit-source-id: 9725d4a3ce6b2fc8cee0b1d1cb8921f9d59c551a
parent b002d009
...@@ -21,4 +21,7 @@ def gelu_accurate(x): ...@@ -21,4 +21,7 @@ def gelu_accurate(x):
def gelu(x: torch.Tensor) -> torch.Tensor: def gelu(x: torch.Tensor) -> torch.Tensor:
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) if hasattr(torch.nn.functional, 'gelu'):
return torch.nn.functional.gelu(x.float()).type_as(x)
else:
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment