Unverified Commit 2b7deffe authored by Vladimir Mandic, committed by GitHub

fix scale_shift_factor being on cpu for wan and ltx (#12347)



* wan fix scale_shift_factor being on cpu

* apply device cast to ltx transformer

* Apply style fixes

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 941ac9c3
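For context, a minimal repro sketch of the failure mode this commit addresses, assuming a toy module with the same access pattern (`TinyBlock` and its shapes are illustrative, not taken from the diff): under CPU offload or multi-GPU inference, `scale_shift_table` can still live on the CPU while `temb` is already on CUDA, so adding the two raises a device-mismatch `RuntimeError`; casting the table with `.to(temb.device)` at the point of use resolves it.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical minimal module mirroring the pattern fixed in this commit."""

    def __init__(self, dim: int):
        super().__init__()
        # The parameter lives wherever the module was left (e.g. CPU after offload).
        self.scale_shift_table = nn.Parameter(torch.randn(2, dim))

    def forward(self, temb: torch.Tensor):
        # Before the fix, `self.scale_shift_table + temb.unsqueeze(1)` raised
        # "Expected all tensors to be on the same device" when temb was on CUDA
        # while the table was still on CPU; the explicit cast avoids that.
        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
        return shift.squeeze(1), scale.squeeze(1)


if torch.cuda.is_available():
    block = TinyBlock(8)                     # parameters stay on CPU
    temb = torch.randn(4, 8, device="cuda")  # conditioning already on GPU
    shift, scale = block(temb)               # works thanks to .to(temb.device)
```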
@@ -353,7 +353,9 @@ class LTXVideoTransformerBlock(nn.Module):
         norm_hidden_states = self.norm1(hidden_states)
 
         num_ada_params = self.scale_shift_table.shape[0]
-        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+        ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+            batch_size, temb.size(1), num_ada_params, -1
+        )
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
 
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
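For readers tracing shapes, here is a quick walk-through of the `ada_values` broadcast in the LTX hunk above; the concrete sizes are assumptions for illustration, not values from the model config:

```python
import torch

batch, tokens, dim = 2, 16, 8
num_ada_params = 6                                       # shift/scale/gate for attn and MLP
scale_shift_table = torch.randn(num_ada_params, dim)     # learned per-parameter offsets
temb = torch.randn(batch, tokens, num_ada_params * dim)  # packed conditioning

# (1, 1, 6, dim) + (batch, tokens, 6, dim) -> (batch, tokens, 6, dim)
ada_values = scale_shift_table[None, None] + temb.reshape(batch, temb.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
assert shift_msa.shape == (batch, tokens, dim)
```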
@@ -682,12 +682,12 @@ class WanTransformer3DModel(
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
-            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
             scale = scale.squeeze(2)
         else:
             # batch_size, inner_dim
-            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
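Similarly, a shape sketch for the `temb.ndim == 3` (wan 2.2 ti2v) branch above; the sizes are illustrative assumptions, only the broadcasting mirrors the diff:

```python
import torch

batch, seq_len, inner_dim = 2, 16, 8
scale_shift_table = torch.randn(1, 2, inner_dim)   # assumed table shape: one shift slot, one scale slot
temb = torch.randn(batch, seq_len, inner_dim)      # per-token conditioning

# (1, 1, 2, inner_dim) + (batch, seq_len, 1, inner_dim) -> (batch, seq_len, 2, inner_dim)
ada = scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)
shift, scale = ada.chunk(2, dim=2)                 # each (batch, seq_len, 1, inner_dim)
shift, scale = shift.squeeze(2), scale.squeeze(2)  # each (batch, seq_len, inner_dim)
assert shift.shape == (batch, seq_len, inner_dim)
```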
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
             control_hidden_states = control_hidden_states + hidden_states
 
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
+            self.scale_shift_table.to(temb.device) + temb.float()
         ).chunk(6, dim=1)
 
         # 1. Self-attention
@@ -361,7 +361,7 @@ class WanVACETransformer3DModel(
             hidden_states = hidden_states + control_hint * scale
 
         # 6. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the