Unverified Commit b62d9a1f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Text-to-video] Add `torch.compile()` compatibility (#3949)

* use sample directly instead of the dataclass.

* more usage of samples directly instead of dataclasses

* more usage of samples directly instead of dataclasses

* use direct sample in the pipeline.

* direct usage of sample in the img2img case.
parent 46af9826
...@@ -250,10 +250,11 @@ class UNetMidBlock3DCrossAttn(nn.Module): ...@@ -250,10 +250,11 @@ class UNetMidBlock3DCrossAttn(nn.Module):
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
).sample return_dict=False,
)[0]
hidden_states = temp_attn( hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
).sample )[0]
hidden_states = resnet(hidden_states, temb) hidden_states = resnet(hidden_states, temb)
hidden_states = temp_conv(hidden_states, num_frames=num_frames) hidden_states = temp_conv(hidden_states, num_frames=num_frames)
...@@ -377,10 +378,11 @@ class CrossAttnDownBlock3D(nn.Module): ...@@ -377,10 +378,11 @@ class CrossAttnDownBlock3D(nn.Module):
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
).sample return_dict=False,
)[0]
hidden_states = temp_attn( hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
).sample )[0]
output_states += (hidden_states,) output_states += (hidden_states,)
...@@ -590,10 +592,11 @@ class CrossAttnUpBlock3D(nn.Module): ...@@ -590,10 +592,11 @@ class CrossAttnUpBlock3D(nn.Module):
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
).sample return_dict=False,
)[0]
hidden_states = temp_attn( hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
).sample )[0]
if self.upsamplers is not None: if self.upsamplers is not None:
for upsampler in self.upsamplers: for upsampler in self.upsamplers:
......
...@@ -526,8 +526,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -526,8 +526,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample = self.conv_in(sample) sample = self.conv_in(sample)
sample = self.transformer_in( sample = self.transformer_in(
sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs sample,
).sample num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# 3. down # 3. down
down_block_res_samples = (sample,) down_block_res_samples = (sample,)
......
...@@ -648,7 +648,8 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora ...@@ -648,7 +648,8 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
).sample return_dict=False,
)[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
...@@ -723,7 +723,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -723,7 +723,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
).sample return_dict=False,
)[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment