Unverified Commit ec9bfa9e authored by skotapati, committed by GitHub

Remove mps workaround for fp16 GELU, which is now supported natively (#10133)



* Remove mps workaround for fp16 GELU, which is now supported natively

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent bdbaea8f
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version
 
 
 if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)
 
     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
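
For context, a minimal sketch (not part of this commit) of the behavior the change relies on: with PyTorch 2.0 or newer, `F.gelu` accepts float16 tensors on the mps backend directly, so the float32 round-trip is only needed on older torch versions. The tensor shape and the availability guard below are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Only meaningful on an Apple Silicon machine with the mps backend available.
if torch.backends.mps.is_available():
    gate = torch.randn(2, 16, device="mps", dtype=torch.float16)

    # With torch >= 2.0, fp16 gelu runs natively on mps; no float32 cast needed.
    out = F.gelu(gate, approximate="tanh")
    print(out.dtype)  # torch.float16
```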