Unverified Commit 95414bd6 authored by Pedro Cuenca, committed by GitHub

Experimental: allow fp16 in `mps` (#961)

* Docs: refer to pre-RC version of PyTorch 1.13.0.

* Remove temporary workaround for unavailable op.

* Update comment to make it less ambiguous.

* Remove use of contiguous in mps.

It no longer appears to be necessary.

* Special case: use einsum for much better performance in mps (a hedged sketch follows the diff below).

* Update mps docs.

* MPS: make pipeline work in half precision (see the usage sketch below).
parent a59f9990
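
As a usage example of the half-precision mps support this commit enables, here is a minimal sketch, assuming the standard `StableDiffusionPipeline.from_pretrained` API; the checkpoint id is chosen purely for illustration:

```python
import torch
from diffusers import StableDiffusionPipeline

# Checkpoint id is only an example; any Stable Diffusion checkpoint works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,  # half precision, now usable on mps
)
pipe = pipe.to("mps")  # requires the PyTorch 1.13.0 pre-release mentioned above

# The first call is slow while mps warms up; later calls are much faster.
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```

Internally, ops that lack a float16 mps kernel (such as gelu, handled in the diff below) are computed in float32 and cast back, so the pipeline stays in half precision end to end.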
@@ -376,6 +376,12 @@ class GEGLU(nn.Module):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
     def forward(self, hidden_states):
         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * F.gelu(gate)
+        return hidden_states * self.gelu(gate)
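
The einsum special case mentioned in the commit message is not part of the hunk shown above. The following is a hedged sketch of the general idea only; the helper name, shapes, and branching are assumptions for illustration, not the commit's exact code. On mps, computing attention scores with `torch.einsum` was reported to be much faster than the equivalent batched-matmul pattern:

```python
import torch

def attention_scores(query, key, scale):
    # query: (batch, q_len, dim), key: (batch, k_len, dim)
    if query.device.type == "mps":
        # Special case for mps: einsum is much faster here.
        return torch.einsum("b i d, b j d -> b i j", query, key) * scale
    # Default path: batched matmul via baddbmm with beta=0 (the empty
    # tensor is never read; it only fixes dtype, device, and shape).
    return torch.baddbmm(
        torch.empty(
            query.shape[0], query.shape[1], key.shape[1],
            dtype=query.dtype, device=query.device,
        ),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
```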