"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "89b74f0aaf0a9a60b36d4241a5578c92a7cced8a"
Unverified Commit b62d9a1f authored by Sayak Paul, committed by GitHub
Browse files

[Text-to-video] Add `torch.compile()` compatibility (#3949)

* use sample directly instead of the dataclass.

* use samples directly instead of dataclasses in more places

* use samples directly instead of dataclasses in more places

* use direct sample in the pipeline.

* direct usage of sample in the img2img case.
parent 46af9826
......@@ -250,10 +250,11 @@ class UNetMidBlock3DCrossAttn(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
......@@ -377,10 +378,11 @@ class CrossAttnDownBlock3D(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
)[0]
output_states += (hidden_states,)
......@@ -590,10 +592,11 @@ class CrossAttnUpBlock3D(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]
hidden_states = temp_attn(
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
......
......@@ -526,8 +526,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample = self.conv_in(sample)
sample = self.transformer_in(
sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
).sample
sample,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# 3. down
down_block_res_samples = (sample,)
......
......@@ -648,7 +648,8 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
......
......@@ -723,7 +723,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment