"torch_harmonics/_neighborhood_attention.py" did not exist on "b3816ebc7d5336069dfc389d2740c5153a1add9f"
q_gemm.cuh 512 Bytes
Newer Older
ilyas@huggingface.co's avatar
ilyas@huggingface.co committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#ifndef _q_gemm_cuh
#define _q_gemm_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#include "q_matrix.cuh"

// Host-side entry point (declaration only; the definition is not in this
// header). From the name and signature it presumably computes a GEMM of the
// half-precision matrix `a` against the quantized matrix `b` (a QMatrix),
// writing half-precision results to `c`.
//
// NOTE(review): every meaning below is inferred from names and types only;
// confirm against the definition.
//
//   cublas_handle  cuBLAS handle; presumably used when the multiply is routed
//                  through cuBLAS rather than a custom kernel.
//   a              read-only left operand; assumed size_m x size_k.
//   b              quantized right operand; assumed size_k x size_n.
//   c              output buffer; assumed size_m x size_n.
//   size_m/size_n/size_k
//                  GEMM dimensions (assumed conventional M, N, K naming).
//   clear          defaults to false; presumably zeroes `c` before the result
//                  is accumulated into it.
//   reconstruct    optional half buffer, defaults to none (nullptr);
//                  presumably scratch space for a dequantized copy of `b`.
//   force_cuda     defaults to false; presumably forces the custom CUDA kernel
//                  path instead of cuBLAS.
//
// Returns void, so no CUDA/cuBLAS error status is reported to the caller
// through the return value.
void gemm_half_q_half_cuda
(
    cublasHandle_t cublas_handle,
    const half* a,
    QMatrix* b,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    bool clear = false,
    half* reconstruct = nullptr,  // nullptr: type-safe C++11 null pointer constant
    bool force_cuda = false
);

// Declaration only; the definition lives in another translation unit. From the
// name it presumably clears (zero-fills) the half-precision buffer `c`, treated
// as size_m x size_n elements. NOTE(review): confirm against the definition.
void clear_tensor_cuda(half* c, int size_m, int size_n);

#endif