Unverified Commit b267d285 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Versatile] fix attention mask (#1763)

parent c7b4acfb
......@@ -959,6 +959,7 @@ class DualTransformer2DModel(nn.Module):
encoded_states = []
tokens_start = 0
# attention_mask is not used yet
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
......@@ -967,7 +968,6 @@ class DualTransformer2DModel(nn.Module):
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
attention_mask=attention_mask,
return_dict=False,
)[0]
encoded_states.append(encoded_state - input_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment