Unverified Commit 54af3ca7 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[chore] allow string device to be passed to randn_tensor. (#11559)

allow string device to be passed to randn_tensor.
parent ba8dc7dc
...@@ -38,7 +38,7 @@ except (ImportError, ModuleNotFoundError): ...@@ -38,7 +38,7 @@ except (ImportError, ModuleNotFoundError):
def randn_tensor( def randn_tensor(
shape: Union[Tuple, List], shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional["torch.device"] = None, device: Optional[Union[str, "torch.device"]] = None,
dtype: Optional["torch.dtype"] = None, dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None, layout: Optional["torch.layout"] = None,
): ):
...@@ -47,6 +47,8 @@ def randn_tensor( ...@@ -47,6 +47,8 @@ def randn_tensor(
is always created on the CPU. is always created on the CPU.
""" """
# device on which tensor is created defaults to device # device on which tensor is created defaults to device
if isinstance(device, str):
device = torch.device(device)
rand_device = device rand_device = device
batch_size = shape[0] batch_size = shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment