Unverified Commit 50e66f2f authored by Sayak Paul, committed by GitHub

[Chore] remove all `i`s from auraflow. (#8980)

remove all `i`s from auraflow.
parent 9b8c8605
@@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
+    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
         residual = hidden_states
 
         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
+        attn_output = self.attn(hidden_states=norm_hidden_states)
 
         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
@@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
+                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                 )
 
         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
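For context, a minimal sketch of how the single DiT block is called after this cleanup. The import path and the constructor arguments (dim, num_attention_heads, attention_head_dim) are assumptions for illustration and are not part of this diff; only the forward signatures shown above are taken from the change itself.

import torch
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowSingleTransformerBlock

# Hypothetical, small configuration chosen only for illustration
# (assumed constructor signature; not shown in this diff).
block = AuraFlowSingleTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)

hidden_states = torch.randn(1, 16, 256)  # (batch, image tokens, dim)
temb = torch.randn(1, 256)               # conditioning embedding of width `dim`

# After this commit, `forward` no longer accepts the leftover `i` index.
out = block(hidden_states=hidden_states, temb=temb)
print(out.shape)  # same shape as `hidden_states`

Passing the old keyword (e.g. block(hidden_states=hidden_states, temb=temb, i=9999)) now raises a TypeError for an unexpected keyword argument, which is the intended effect of dropping the unused parameter.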