utils.py 512 Bytes
Newer Older
Carl Case's avatar
Carl Case committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch

# Legacy CUDA tensor-type names, presumably compared against the string
# returned by ``Tensor.type()`` in the tests that import this module.
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'

# Floating-point dtypes the tests iterate over (half first, then float).
DTYPES = [torch.half, torch.float]

# Expected-output tables: each maps an input dtype (float, then half) to the
# type name the result should have.
# Output is half regardless of the input dtype.
ALWAYS_HALF = dict.fromkeys((torch.float, torch.half), HALF)
# Output is float regardless of the input dtype.
ALWAYS_FLOAT = dict.fromkeys((torch.float, torch.half), FLOAT)
# Output keeps the precision of the input.
MATCH_INPUT = dict(zip((torch.float, torch.half), (FLOAT, HALF)))

def common_init(test_case):
    """Attach the shared problem sizes to *test_case* and default to CUDA floats.

    Sets the integer attributes ``h``, ``b``, ``c``, ``k`` and ``t`` on
    *test_case*. It then makes ``torch.cuda.FloatTensor`` the default tensor
    type. Presumably this is called from a ``unittest.TestCase.setUp``;
    confirm against callers.

    Args:
        test_case: Any object that accepts attribute assignment.

    Note:
        The default-tensor-type change is process-global and is not undone
        here. It presumably needs a CUDA-enabled PyTorch build and a visible
        GPU; otherwise the ``torch.set_default_tensor_type`` call fails
        after the attributes have already been set.
    """
    # Shared problem sizes.
    # NOTE(review): the names suggest height, batch size, channels, kernel
    # size and sequence length; confirm against the tests that read them.
    test_case.h = 64
    test_case.b = 16
    test_case.c = 16
    test_case.k = 3
    test_case.t = 10
    # Global side effect: tensors created without an explicit type after this
    # point default to torch.cuda.FloatTensor.
    # NOTE(review): torch.set_default_tensor_type is deprecated as of
    # PyTorch 2.1 (torch.set_default_dtype + torch.set_default_device are the
    # suggested replacements). It is kept for compatibility with older releases.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)