"llm/llama.cpp/ggml/include/ggml-vulkan.h" did not exist on "ecd2f176277db4f074e25a2c3646b04b51cec119"
utils.py 720 Bytes
Newer Older
Carl Case's avatar
Carl Case committed
1
2
3
4
import torch

# CUDA tensor type names used as expected-result markers.
# NOTE(review): these match the legacy type-string format; presumably callers
# compare them against ``Tensor.type()`` output — confirm.
HALF, FLOAT, BFLOAT16 = (
    'torch.cuda.HalfTensor',
    'torch.cuda.FloatTensor',
    'torch.cuda.BFloat16Tensor',
)

# Input dtype sets to iterate over: fp16/fp32 and bf16/fp32 respectively.
DTYPES = [torch.half, torch.float]
DTYPES2 = [torch.bfloat16, torch.float]

# Dtype -> type-name lookup tables. The names suggest "expected output type
# for a given input dtype" (confirm against callers). Key order matches the
# original literals. Note which keys are absent: ALWAYS_HALF and ALWAYS_FLOAT
# have no torch.bfloat16 entry, and ALWAYS_BFLOAT16 has no torch.half entry,
# so indexing them with those dtypes raises KeyError.
ALWAYS_HALF = dict.fromkeys((torch.float, torch.half), HALF)
ALWAYS_BFLOAT16 = dict.fromkeys((torch.bfloat16, torch.float), BFLOAT16)
ALWAYS_FLOAT = dict.fromkeys((torch.float, torch.half), FLOAT)
MATCH_INPUT = {
    torch.float: FLOAT,
    torch.half: HALF,
    torch.bfloat16: BFLOAT16,
}
Carl Case's avatar
Carl Case committed
20
21
22
23
24
25

def common_init(test_case):
    """Attach the shared problem sizes to *test_case* and switch torch's default tensor type.

    Sets these integer attributes on *test_case*, in this order:
    ``h`` = 64, ``b`` = 16, ``c`` = 16, ``k`` = 3, ``t`` = 10.
    The single-letter names hint at height/batch/channel/kernel/time sizes;
    confirm with callers. It then calls
    ``torch.set_default_tensor_type(torch.cuda.FloatTensor)``. That is a global
    torch setting, and this function never restores the previous default.

    Args:
        test_case: Any object that accepts attribute assignment. Presumably a
            ``unittest.TestCase`` instance, e.g. from ``setUp``; confirm
            against callers.

    NOTE(review): on a PyTorch build without CUDA the final call is expected to
    fail because the CUDA tensor type is unavailable. Because the attributes
    are assigned first, they remain set in that case.
    NOTE(review): ``set_default_tensor_type`` is deprecated in recent PyTorch
    releases in favour of ``set_default_dtype`` / ``set_default_device``.
    Consider migrating once the minimum supported torch version allows it.
    """
    shared_sizes = (
        ("h", 64),
        ("b", 16),
        ("c", 16),
        ("k", 3),
        ("t", 10),
    )
    for attr_name, attr_value in shared_sizes:
        setattr(test_case, attr_name, attr_value)

    torch.set_default_tensor_type(torch.cuda.FloatTensor)