# utils.py (committed by Carl Case)
import torch

# CUDA tensor type strings, in the format returned by ``Tensor.type()``.
# NOTE(review): presumably compared against that output by the tests;
# confirm against callers.
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'
BFLOAT16 = 'torch.cuda.BFloat16Tensor'

# Input dtypes listed for iteration; bfloat16 is deliberately not part of
# this list (it only appears in the BFLOAT16 tables below).
DTYPES = [torch.half, torch.float]

# Lookup tables from an input dtype to a CUDA tensor type string. Judging by
# their names, these are presumably the expected output types of operations
# that always cast to half / bfloat16 / float, or that keep the input's
# type (MATCH_INPUT) -- TODO confirm against the tests that use them.
ALWAYS_HALF = {torch.float: HALF,
               torch.half: HALF}
ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
                   torch.float: BFLOAT16}
ALWAYS_FLOAT = {torch.float: FLOAT,
                torch.half: FLOAT}
MATCH_INPUT = {torch.float: FLOAT,
               torch.half: HALF,
               torch.bfloat16: BFLOAT16}


def common_init(test_case):
    """Attach the shared size parameters to *test_case* and set the default tensor type.

    Sets the attributes ``h``, ``b``, ``c``, ``k`` and ``t`` on *test_case*,
    then makes ``torch.cuda.FloatTensor`` the global default tensor type, so
    tensors created afterwards without an explicit type are float32 CUDA
    tensors.

    Args:
        test_case: Object that receives the attributes (presumably a
            ``unittest.TestCase`` calling this from ``setUp`` -- confirm
            against callers).
    """
    # Shared size parameters. The meaning of each single-letter name is not
    # established in this file -- TODO document once confirmed with callers.
    test_case.h = 64
    test_case.b = 16
    test_case.c = 16
    test_case.k = 3
    test_case.t = 10
    # NOTE(review): process-wide side effect that outlives this call.
    # ``set_default_tensor_type`` is deprecated in recent PyTorch releases in
    # favour of ``torch.set_default_dtype`` / ``torch.set_default_device``;
    # kept here for compatibility with older versions.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)