Unverified Commit 50e66f2f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Chore] remove all is from auraflow. (#8980)

remove all is from auraflow.
parent 9b8c8605
...@@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module): ...@@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4) self.ff = AuraFlowFeedForward(dim, dim * 4)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999): def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
residual = hidden_states residual = hidden_states
# Norm + Projection. # Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention. # Attention.
attn_output = self.attn(hidden_states=norm_hidden_states, i=i) attn_output = self.attn(hidden_states=norm_hidden_states)
# Process attention outputs for the `hidden_states`. # Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
...@@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module): ...@@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
self.ff_context = AuraFlowFeedForward(dim, dim * 4) self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward( def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0 self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
): ):
residual = hidden_states residual = hidden_states
residual_context = encoder_hidden_states residual_context = encoder_hidden_states
...@@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module): ...@@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
# Attention. # Attention.
attn_output, context_attn_output = self.attn( attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
) )
# Process attention outputs for the `hidden_states`. # Process attention outputs for the `hidden_states`.
...@@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): ...@@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
else: else:
encoder_hidden_states, hidden_states = block( encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
) )
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment