Unverified commit a3b402ff authored by Prathik Rao, committed by GitHub

llama fp16 torch.max bug fix (#24561)



* open llama fp16 bug fix

* bug fix

* bug fixed

* make style

* Update modeling_llama.py

* apply formatting

* Address Amy's comment

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
parent 4e945660
@@ -224,9 +224,10 @@ class LlamaAttention(nn.Module):
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(
-                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+            dtype_min = torch.tensor(
+                torch.finfo(attn_weights.dtype).min, device=attn_weights.device, dtype=attn_weights.dtype
             )
+            attn_weights = torch.max(attn_weights, dtype_min)
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
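For context (not part of the commit itself), here is a minimal sketch of the dtype mismatch the patch addresses, assuming a plain PyTorch environment; the tensor shape and variable names are illustrative only:

    import torch

    # Stand-in for fp16 attention scores (attn_weights in the patch).
    attn_weights = torch.randn(1, 1, 4, 4, dtype=torch.float16)

    # Before the fix: torch.tensor() with no dtype argument infers float32
    # from the Python float returned by torch.finfo(...).min, so the clamp
    # value does not match attn_weights when the model runs in fp16.
    old_min = torch.tensor(torch.finfo(attn_weights.dtype).min)
    print(old_min.dtype)  # torch.float32

    # After the fix: the clamp value is created directly in the working dtype,
    # so both operands of torch.max share the same dtype.
    dtype_min = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)
    print(dtype_min.dtype)  # torch.float16

    attn_weights = torch.max(attn_weights, dtype_min)
    print(attn_weights.dtype)  # torch.float16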