Unverified Commit 44f6b859 authored by Sayak Paul, committed by GitHub

[Core] refactor `transformer_2d` forward logic into meaningful conditions. (#7489)



* refactor transformer_2d forward logic into meaningful conditions.

* Empty-Commit

* fix: _operate_on_patched_inputs

* fix: _operate_on_patched_inputs

* check

* fix: patch output computation block.

* fix: _operate_on_patched_inputs.

* remove print.

* move operations to blocks.

* more readability neats.

* empty commit

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Revert "Apply suggestions from code review"

This reverts commit 12178b1aa0da3c29434e95a2a0126cf3ef5706a7.

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent ac7ff7d4
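The diff below splits the monolithic `forward` into per-input-type helpers (`_operate_on_continuous_inputs`, `_operate_on_patched_inputs`, `_get_output_for_continuous_inputs`, `_get_output_for_vectorized_inputs`, `_get_output_for_patched_inputs`), leaving `forward` itself as a dispatcher over `is_input_continuous`, `is_input_vectorized`, and `is_input_patches`. A minimal sketch of that shape, with stub helpers rather than the diffusers implementation:

```python
# Hypothetical sketch only: stub helpers standing in for the real diffusers methods.
class ForwardDispatchSketch:
    def __init__(self, input_type: str):
        self.is_input_continuous = input_type == "continuous"
        self.is_input_vectorized = input_type == "vectorized"
        self.is_input_patches = input_type == "patches"
        self.transformer_blocks = []  # stand-in for the real attention blocks

    # Per-input-type preprocessing (section "1. Input" in the diff).
    def _operate_on_continuous_inputs(self, hidden_states):
        return hidden_states, None  # the real method also returns inner_dim

    def _operate_on_patched_inputs(self, hidden_states):
        return hidden_states, None  # the real method also returns conditioning tensors

    # Per-input-type output projection (section "3. Output" in the diff).
    def _get_output_for_continuous_inputs(self, hidden_states):
        return hidden_states

    def _get_output_for_vectorized_inputs(self, hidden_states):
        return hidden_states

    def _get_output_for_patched_inputs(self, hidden_states):
        return hidden_states

    def forward(self, hidden_states):
        # 1. Input: one helper per input type instead of inlined branches.
        if self.is_input_continuous:
            hidden_states, _ = self._operate_on_continuous_inputs(hidden_states)
        elif self.is_input_patches:
            hidden_states, _ = self._operate_on_patched_inputs(hidden_states)

        # 2. Blocks: unchanged by the refactor.
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)

        # 3. Output: again, one helper per input type.
        if self.is_input_continuous:
            return self._get_output_for_continuous_inputs(hidden_states)
        if self.is_input_vectorized:
            return self._get_output_for_vectorized_inputs(hidden_states)
        return self._get_output_for_patched_inputs(hidden_states)


print(ForwardDispatchSketch("continuous").forward([0.0]))  # [0.0]
```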
@@ -402,41 +402,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )

         # 2. Blocks
-        if self.is_input_patches and self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
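In the patched branch of the new input section, the token grid is the latent resolution divided by the patch size. A quick standalone check; the latent shape and patch size below are illustrative assumptions, not values from the PR:

```python
import torch

# Illustrative shapes only: a 64x64 latent with patch_size=2 gives a 32x32 token grid.
patch_size = 2
hidden_states = torch.randn(1, 4, 64, 64)  # (batch, channels, height, width)

height = hidden_states.shape[-2] // patch_size
width = hidden_states.shape[-1] // patch_size

print(height, width, height * width)  # 32 32 1024
```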
@@ -474,24 +451,93 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-            output = hidden_states + residual
-        elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
+        elif self.is_input_vectorized:
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
+
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
+
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+                )
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+
+        output = hidden_states + residual
+        return output
+
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
@@ -517,8 +563,4 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         output = hidden_states.reshape(
             shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
         )
-
-        if not return_dict:
-            return (output,)
-
-        return Transformer2DModelOutput(sample=output)
+        return output
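Because the refactor only moves logic into private helpers, calls through the public API should behave exactly as before. A small usage sketch for the continuous-input path; the constructor arguments are illustrative and assume the standard `Transformer2DModel` signature in diffusers at the time of this PR:

```python
import torch
from diffusers.models import Transformer2DModel

# Illustrative configuration (assumed, not taken from the PR): a tiny continuous-input
# model, i.e. in_channels is set and no patch_size/num_vector_embeds is given.
model = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=32,  # inner_dim = 1 * 32 = 32
    in_channels=32,         # divisible by the default norm_num_groups=32
    num_layers=1,
)
model.eval()

sample = torch.randn(2, 32, 16, 16)  # (batch, channels, height, width)
with torch.no_grad():
    out = model(sample).sample       # forward() dispatches to the continuous-input helpers

print(out.shape)  # torch.Size([2, 32, 16, 16])
```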