Unverified commit b267d285, authored by Patrick von Platen and committed by GitHub

[Versatile] fix attention mask (#1763)

parent c7b4acfb
@@ -959,6 +959,7 @@ class DualTransformer2DModel(nn.Module):
         encoded_states = []
         tokens_start = 0
+        # attention_mask is not used yet
         for i in range(2):
             # for each of the two transformers, pass the corresponding condition tokens
             condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
@@ -967,7 +968,6 @@ class DualTransformer2DModel(nn.Module):
                 input_states,
                 encoder_hidden_states=condition_state,
                 timestep=timestep,
-                attention_mask=attention_mask,
                 return_dict=False,
             )[0]
             encoded_states.append(encoded_state - input_states)
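For context, the change documents that `DualTransformer2DModel` does not yet apply an attention mask and stops forwarding `attention_mask` to the inner `Transformer2DModel` calls. The loop after this change works roughly as in the sketch below. This is a minimal, self-contained illustration, not the actual diffusers code: the names `DualTransformerSketch` and `StubTransformer`, the condition lengths (77/257), and `mix_ratio` default are assumptions made for the example.

```python
# Minimal sketch of the pattern this commit settles on (assumed names, not the real API).
import torch
import torch.nn as nn


class StubTransformer(nn.Module):
    """Stands in for Transformer2DModel; returns a tuple, as return_dict=False would."""

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict=False):
        return (hidden_states,)


class DualTransformerSketch(nn.Module):
    def __init__(self, condition_lengths=(77, 257), mix_ratio=0.5):
        super().__init__()
        self.transformers = nn.ModuleList([StubTransformer(), StubTransformer()])
        self.condition_lengths = list(condition_lengths)
        self.mix_ratio = mix_ratio

    def forward(self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None):
        input_states = hidden_states
        encoded_states = []
        tokens_start = 0
        # attention_mask is accepted for API compatibility but is not used yet,
        # which is why the commit drops it from the inner transformer calls
        for i in range(2):
            # each sub-transformer sees only its own slice of the condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            encoded_state = self.transformers[i](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                return_dict=False,  # attention_mask is intentionally not passed here
            )[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        # blend the two residuals, then add the input back
        output = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        return output + input_states


if __name__ == "__main__":
    model = DualTransformerSketch()
    h = torch.randn(2, 64, 320)           # (batch, sequence, channels)
    cond = torch.randn(2, 77 + 257, 768)  # concatenated condition tokens
    print(model(h, cond).shape)           # torch.Size([2, 64, 320])
```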