Unverified commit 0a151115 authored by Akshay Babbar, committed by GitHub

Fix #12116: preserve boolean dtype for attention masks in ChromaPipeline (#12263)



* fix: preserve boolean dtype for attention masks in ChromaPipeline

- Convert attention masks to bool and prevent dtype corruption
- Fix both positive and negative mask handling in _get_t5_prompt_embeds
- Remove float conversion in _prepare_attention_mask method

Fixes #12116

* test: add ChromaPipeline attention mask dtype tests

* test: add slow ChromaPipeline attention mask tests

* chore: removed comments

* refactor: removing redundant type conversion

* Remove dedicated dtype tests as per feedback

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 19085ac8
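For context on the fix described above, a minimal, self-contained sketch of the boolean mask construction (tensor values are illustrative, not taken from the pipeline):

```python
import torch

# Hypothetical tokenizer output: 3 real tokens followed by 2 padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
batch_size = attention_mask.size(0)

# Chroma keeps one padding token visible, so `<=` (not `<`) is used below.
seq_lengths = attention_mask.sum(dim=1)  # tensor([3])
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

print(attention_mask)  # tensor([[ True,  True,  True,  True, False]])
```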
@@ -238,7 +238,7 @@ class ChromaPipeline(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
@@ -246,7 +246,7 @@ class ChromaPipeline(
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)
         _, seq_len, _ = prompt_embeds.shape
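A small sketch of the corrected cast above (dtype, device, and shapes are assumed for illustration): the prompt embeddings take the text encoder's dtype, while the mask only changes device and stays boolean.

```python
import torch

# Assumed stand-ins for the pipeline's runtime values.
dtype, device = torch.float16, "cpu"
prompt_embeds = torch.randn(1, 5, 8)
attention_mask = torch.ones(1, 5, dtype=torch.bool)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# The fix: no `dtype=dtype` here, so the mask keeps torch.bool.
attention_mask = attention_mask.to(device=device)
assert attention_mask.dtype is torch.bool
```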
@@ -605,10 +605,9 @@ class ChromaPipeline(
         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)
         return attention_mask
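The hunk above matters because `torch.cat` type-promotes mixed inputs; a short demonstration with hypothetical shapes:

```python
import torch

prompt_mask = torch.ones(2, 4, dtype=torch.bool)  # hypothetical text-token mask
image_ones = torch.ones(2, 3)                     # defaults to torch.float32

# Without an explicit dtype, cat promotes bool + float32 to float32,
# which is the silent dtype corruption the patch removes.
print(torch.cat([prompt_mask, image_ones], dim=1).dtype)  # torch.float32

# With dtype=torch.bool, the concatenated mask stays boolean end to end.
image_bool = torch.ones(2, 3, dtype=torch.bool)
print(torch.cat([prompt_mask, image_bool], dim=1).dtype)  # torch.bool
```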