"tests/utils/test_tokenization_utils.py" did not exist on "a611ac9b3f9493a80e2d0adf491f4868c71f71c5"
Unverified commit 74cae670, authored by Ke Wen and committed by GitHub
Browse files

Make GPT2 traceable in meta state (#28054)

* Put device in tensor constructor instead of to()

* Fix copy
parent e2b6df79
@@ -185,7 +185,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
         if attention_mask is not None:
@@ -198,7 +198,7 @@ class GPT2Attention(nn.Module):
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
         if attention_mask is not None:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.