Unverified Commit dc72fd7e authored by Ella Charlaix's avatar Ella Charlaix Committed by GitHub
Browse files

Requires for torch.tensor before casting (#31755)

parent 7f91f168
......@@ -762,7 +762,7 @@ def torch_int(x):
import torch
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
def torch_float(x):
......@@ -774,7 +774,7 @@ def torch_float(x):
import torch
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment