"...text-generation-inference.git" did not exist on "8642250602bf9d0cf0fd7312f3e643458f3b3e09"
Unverified Commit babfb8a0 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[MPS] call contiguous after permute (#1411)

* call contiguous after permute

Fixes for MPS device

* Fix MPS UserWarning

* make style

* Revert "Fix MPS UserWarning"

This reverts commit b46c32810ee5fdc4c16a8e9224a826490b66cf49.
parent 35099b20
...@@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output # 3. Output
if self.is_input_continuous: if self.is_input_continuous:
if not self.use_linear_projection: if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
else: else:
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual output = hidden_states + residual
elif self.is_input_vectorized: elif self.is_input_vectorized:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment