// grouped_gemm_ck.h
// SPDX-License-Identifier: MIT

#pragma once

#include <torch/extension.h>

#include <vector>

std::vector<torch::Tensor>
ck_grouped_gemm(std::vector<torch::Tensor>& a_tensors, std::vector<torch::Tensor>& b_tensors);

std::vector<torch::Tensor> ck_grouped_gemm_out(std::vector<torch::Tensor>& a_tensors,
                                               std::vector<torch::Tensor>& b_tensors,
                                               std::vector<torch::Tensor>& c_tensors);