custom.h 309 Bytes
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

void wvSpltK(at::Tensor &in_a, at::Tensor &in_b, at::Tensor &out_c,
             const int64_t N_in, const int64_t CuCount);

void LLMM1(
    at::Tensor &in_a, at::Tensor &in_b, at::Tensor &out_c,
    const int64_t rows_per_block);