"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "ee6a4de2a4f396b532909576d0d02f78ab33799c"
Commit ce649d61 authored by comfyanonymous

Allow zeroing out of embeds with unused attention mask.

parent b4c2d03d
```diff
@@ -169,7 +169,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = torch.LongTensor(tokens).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks:
+        if self.enable_attention_masks or self.zero_out_masked:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
@@ -178,7 +178,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     if tokens[x, y] == end_token:
                         break
 
-        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        attention_mask_model = None
+        if self.enable_attention_masks:
+            attention_mask_model = attention_mask
+
+        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
@@ -186,7 +190,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         else:
             z = outputs[1].float()
 
-        if self.zero_out_masked and attention_mask is not None:
+        if self.zero_out_masked:
             z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
```
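In short, the mask is now built whenever either flag is set, but it is only forwarded to the transformer when `enable_attention_masks` is on; `zero_out_masked` instead reuses the mask afterwards to zero the output embeddings at padding positions. Below is a minimal standalone sketch of that zero-out step, under simplified assumptions (the function name, toy token ids, and embedding shapes are illustrative, not the `SDClipModel` API):

```python
import torch

# Hypothetical helper (not the actual SDClipModel interface): build a mask that
# is 1 up to and including the first end token in each row and 0 afterwards,
# then zero out the embeddings at the masked positions.
def zero_out_masked_embeds(tokens: torch.Tensor, embeds: torch.Tensor, end_token: int) -> torch.Tensor:
    attention_mask = torch.zeros_like(tokens)
    for x in range(attention_mask.shape[0]):
        for y in range(attention_mask.shape[1]):
            attention_mask[x, y] = 1
            if tokens[x, y] == end_token:
                break  # everything after the end token stays masked
    # Broadcast the (batch, seq) mask over the embedding dimension.
    return embeds * attention_mask.unsqueeze(-1).float()

# Toy usage: sequence of 5 with the end token at position 2, so the
# embeddings at the padded positions 3 and 4 are zeroed out.
tokens = torch.tensor([[49406, 320, 49407, 49407, 49407]])
embeds = torch.randn(1, 5, 4)
out = zero_out_masked_embeds(tokens, embeds, end_token=49407)
assert torch.all(out[0, 3:] == 0)
```

The effect is that padding tokens no longer contribute to the conditioning even when the transformer itself is run without an attention mask.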