Unverified Commit b6156aaf authored by AstraliteHeart, committed by GitHub

Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible (#11297)



* Update pe_selection_index_based_on_dim

* Make pe_selection_index_based_on_dim work with torch.compile

* Fix AuraFlowTransformer2DModel's docstring default values

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 7ecfe291
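
A hedged usage sketch of what this change targets (not part of the commit): compiling the AuraFlow transformer with `fullgraph=True`, which is strict about graph breaks such as the one the rewritten index selection avoids. The checkpoint id `fal/AuraFlow`, the dtype, and the prompt below are assumptions.

import torch
from diffusers import AuraFlowPipeline

# Assumed checkpoint id and dtype; adjust to your setup.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# fullgraph=True raises on graph breaks, so it surfaces compile
# incompatibilities such as the one addressed in AuraFlowPatchEmbed.
pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)

image = pipe("a watercolor painting of a fox", num_inference_steps=25).images[0]
image.save("fox.png")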
@@ -74,15 +74,23 @@ class AuraFlowPatchEmbed(nn.Module):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices
 
     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
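
For reference, a minimal standalone sketch (not part of the diff) that checks the rewritten arithmetic index computation against the old slice-based one and compiles it with `fullgraph=True`. The `pos_embed_max_size` and `patch_size` values are illustrative assumptions.

import torch

pos_embed_max_size = 1024  # assumed; matches the documented default
patch_size = 2             # assumed
h_max = w_max = int(pos_embed_max_size**0.5)

def old_selection(h, w):
    # old approach: materialize the full index grid, then slice the centered window
    h_p, w_p = h // patch_size, w // patch_size
    idx = torch.arange(pos_embed_max_size).view(h_max, w_max)
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2
    return idx[starth : starth + h_p, startw : startw + w_p].flatten()

def new_selection(h, w):
    # rewritten approach: compute the flattened indices arithmetically
    h_p, w_p = h // patch_size, w // patch_size
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2
    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
    return (row_indices * w_max + col_indices).flatten()

# the two selections agree for latent sizes that fit inside the PE grid
for h, w in [(64, 64), (32, 48), (64, 32)]:
    assert torch.equal(old_selection(h, w), new_selection(h, w))

# fullgraph=True raises if tracing the function produces graph breaks
compiled = torch.compile(new_selection, fullgraph=True)
print(compiled(64, 64).shape)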
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
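
The second hunk only corrects documented defaults; below is a quick hedged way to cross-check them against the actual constructor signature (a sketch, assuming `diffusers` at this commit is installed and that the config decorator preserves the signature for `inspect`):

import inspect
from diffusers import AuraFlowTransformer2DModel

sig = inspect.signature(AuraFlowTransformer2DModel.__init__)
# expected per the corrected docstring: 4, 32, 256, 12, 4, 1024
for name in ("in_channels", "num_single_dit_layers", "attention_head_dim",
             "num_attention_heads", "out_channels", "pos_embed_max_size"):
    print(name, "=", sig.parameters[name].default)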