Unverified Commit 2844c5de authored by Jingya HUANG, committed by GitHub

Fix ORTTrainer failure on gpt2 fp16 training (#18017)



* Ensure value and attn weights have the same dtype

* Remove prints

* Modify decision transformers copied from gpt2

* Nit device
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix style
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 2b096508
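
A minimal standalone sketch (not code from this commit; the tensor shapes and the repro itself are assumptions) of why the scaling divisor is wrapped in a tensor with an explicit dtype and device. Dividing by a bare Python float works in eager PyTorch, but once the GPT-2 graph is traced and run through ONNX Runtime for fp16 training, the scalar can surface as a float32 constant and clash with the fp16 attention weights; building the divisor from attn_weights itself keeps the dtypes consistent.

import torch

# Hypothetical repro: "value" stands in for the attention value states,
# "attn_weights" for the query-key scores. Shapes are arbitrary.
value = torch.randn(1, 12, 8, 64, dtype=torch.float16)
attn_weights = torch.randn(1, 12, 8, 8, dtype=torch.float16)

# Old form: divide by a Python float. Fine in eager mode, but the bare
# scalar may become a float32 constant in the exported/ORT graph.
scaled_old = attn_weights / (value.size(-1) ** 0.5)

# New form (as in this commit): materialize the divisor with the same
# dtype and device as attn_weights, so the division stays in fp16.
scale = torch.tensor(
    value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
scaled_new = attn_weights / scale

assert scaled_new.dtype == torch.float16
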
@@ -178,7 +178,9 @@ class DecisionTransformerGPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1) ** 0.5)
+            attn_weights = attn_weights / torch.tensor(
+                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -189,7 +189,9 @@ class GPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1) ** 0.5)
+            attn_weights = attn_weights / torch.tensor(
+                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx: