Unverified Commit 29294b0e authored by zspo, committed by GitHub

Fix tensor device while attention_mask is not None (#23538)

* Fix tensor device while attention_mask is not None

* Fix tensor device while attention_mask is not None
parent 12ec7f0c
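
The change is identical in all four attention implementations touched below (LlamaAttention, OpenLlamaAttention, OPTAttention, XGLMAttention): torch.tensor(scalar) allocates on the default device (normally CPU), so when attn_weights lives on a GPU, torch.max raises a device-mismatch RuntimeError; passing device=attn_weights.device creates the floor tensor alongside attn_weights. A minimal repro-and-fix sketch (illustrative only, not part of the commit; it assumes a CUDA device is available, since on a CPU-only machine both tensors share a device and the bug never surfaces):

import torch

device = "cuda"  # assumption: a CUDA device is present; on CPU-only machines the mismatch cannot occur
attn_weights = torch.randn(2, 4, device=device)

# Before the fix: the scalar tensor defaults to CPU, so the elementwise max raises
# "Expected all tensors to be on the same device" when attn_weights is on CUDA.
# floor = torch.tensor(torch.finfo(attn_weights.dtype).min)

# After the fix: the floor tensor is created on attn_weights' own device,
# so the comparison is always valid regardless of where the model runs.
floor = torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
attn_weights = torch.max(attn_weights, floor)
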
@@ -226,7 +226,9 @@ class LlamaAttention(nn.Module):
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
             )
         attn_weights = attn_weights + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.max(
+            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+        )
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -249,7 +249,9 @@ class OpenLlamaAttention(nn.Module):
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
             )
         attn_weights = attn_weights + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.max(
+            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+        )
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -222,7 +222,9 @@ class OPTAttention(nn.Module):
                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
             )
         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.max(
+            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+        )
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
@@ -315,7 +315,9 @@ class XGLMAttention(nn.Module):
                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
             )
         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.max(
+            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+        )
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
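
As a side note (an alternative not taken in this commit): torch.clamp accepts a plain Python number for its min bound, so the same flooring can be written without allocating a companion tensor, which sidesteps the device question entirely. A hedged sketch:

import torch

attn_weights = torch.randn(2, 4)

# clamp with a Python float needs no scalar tensor and therefore no device argument
attn_weights = torch.clamp(attn_weights, min=torch.finfo(attn_weights.dtype).min)
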