Unverified Commit f947e703 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Fix bug when deducing dtype in linear functional API (#2017)



Fix bug when deducing dtype in linear functional API
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent bfca2e33
......@@ -426,7 +426,7 @@ class BasicLinear(BasicOperation):
if dtype is None:
if out is not None and isinstance(out, torch.Tensor):
dtype = out.dtype
elif weight is not None and isinstance(out, torch.Tensor):
elif weight is not None and isinstance(weight, torch.Tensor):
dtype = weight.dtype
else:
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment