Unverified Commit e7ffeae0 authored by 39th president of the United States, probably's avatar 39th president of the United States, probably Committed by GitHub
Browse files

Fix for multi-GPU WAN inference (#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>
parent d87ce2ce
......@@ -441,6 +441,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment