Unverified Commit 521da651 authored by Jingya HUANG's avatar Jingya HUANG Committed by GitHub
Browse files

Fix gpt2 fp16 training when tracing is enabled (#20656)

* ONNX tracing fix

* Remove conditional
parent 93b54368
...@@ -186,7 +186,7 @@ class DecisionTransformerGPT2Attention(nn.Module): ...@@ -186,7 +186,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value) attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask # Apply the attention mask
......
...@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module): ...@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value) attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask # Apply the attention mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment