testing.py 336 Bytes
Newer Older
rusty1s's avatar
update  
rusty1s committed
1
2
from typing import Any

rusty1s's avatar
rusty1s committed
3
4
import torch

5
# Floating-point dtypes the test-suite is parametrized over.
dtypes = [torch.float, torch.double, torch.bfloat16]

# Devices the test-suite is parametrized over: the CPU is always included,
# and the first CUDA device is appended only when CUDA is available.
devices = [torch.device('cpu')] + (
    [torch.device('cuda:0')] if torch.cuda.is_available() else [])
rusty1s's avatar
rusty1s committed
10
11


rusty1s's avatar
update  
rusty1s committed
12
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
    """Build a tensor from ``x`` with the given dtype and device.

    Args:
        x: Data accepted by :func:`torch.tensor` (for example a nested list
            of numbers), or ``None``.
        dtype: Element type of the resulting tensor.
        device: Device on which the resulting tensor is created.

    Returns:
        ``torch.tensor(x, dtype=dtype, device=device)``, or ``None`` when
        ``x`` is ``None`` (``None`` is passed through unchanged).
    """
    if x is None:
        return None
    return torch.tensor(x, dtype=dtype, device=device)