Unverified Commit e77735ef authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #65 from YdrMaster/main

issue/63 重构算子定义文件结构及风格修改
parents b7893d65 3144cc9c
#include "matmul_bang.h"
#include "../../../devices/bang/bang_handle.h"
#include "../../../devices/bang/common_bang.h"
#include "../../utils.h"
#include <cnnl_extra.h>
namespace matmul::bang {
// Hardware-specific state for the Cambricon (BANG) matmul descriptor.
// Bundles the CNNL descriptors and the algorithm chosen at create() time
// so calculate() can launch without re-querying CNNL.
struct Descriptor::Opaque {
cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
// CNNL tensor descriptors for the a, b and c matrices.
cnnlTensorDescriptor_t a, b, c;
// Shared CNNL handle pool; the shared_ptr keeps the pool alive for as
// long as any descriptor references it.
std::shared_ptr<Pool<cnnlHandle_t>> cnnl_handle_pool;
// Releases every CNNL object acquired in Descriptor::create().
~Opaque() {
cnnlDestroyTensorDescriptor(a);
cnnlDestroyTensorDescriptor(b);
cnnlDestroyTensorDescriptor(c);
cnnlMatMulDescDestroy(op);
cnnlMatMulAlgoDestroy(algo);
cnnlDestroyMatMulHeuristicResult(algoResult);
}
};
// Configure `desc` as a CNNL strided tensor matching `matrix`.
// Only 2-D and 3-D matrices are handled; any other rank leaves `desc`
// untouched. (`trans` is accepted for interface compatibility but is not
// consulted here.)
static void setMatrixTensorEx(
    cnnlTensorDescriptor_t desc,
    const BlasMatrix &matrix, infiniDtype_t dtype,
    bool trans = false) {
    std::vector<int> shape;
    std::vector<int> strides;
    if (matrix.ndim == 3) {
        // Batched case: leading dimension is the batch axis.
        shape = {static_cast<int>(matrix.batch),
                 static_cast<int>(matrix.rows),
                 static_cast<int>(matrix.cols)};
        strides = {static_cast<int>(matrix.stride),
                   static_cast<int>(matrix.row_stride),
                   static_cast<int>(matrix.col_stride)};
    } else if (matrix.ndim == 2) {
        shape = {static_cast<int>(matrix.rows),
                 static_cast<int>(matrix.cols)};
        strides = {static_cast<int>(matrix.row_stride),
                   static_cast<int>(matrix.col_stride)};
    } else {
        return;
    }
    cnnlSetTensorDescriptorEx(
        desc, CNNL_LAYOUT_ARRAY,
        cnnlDataTypeConvert(dtype), shape.size(),
        shape.data(), strides.data());
}
// Frees the hardware-specific state; Opaque's destructor releases the
// underlying CNNL objects.
Descriptor::~Descriptor() {
delete _opaque;
}
// Build a BANG matmul descriptor: validates dtypes, derives the (batched)
// GEMM shape from the tensor descriptors, creates the CNNL descriptors and
// picks a matmul algorithm heuristically.
//
// Only F16 and F32 are supported. Returns INFINIOP_STATUS_SUCCESS on
// success, or the first failing status otherwise.
infiniopStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {
    auto handle = reinterpret_cast<infiniopBangHandle_t>(handle_);
    auto dtype = c_desc->dtype;
    if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
    infiniopStatus_t status;
    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
    if (status != INFINIOP_STATUS_SUCCESS) {
        return status;
    }
    // CNNL tensor descriptors mirroring the BLAS matrix layouts.
    cnnlTensorDescriptor_t a, b, c;
    cnnlCreateTensorDescriptor(&a);
    cnnlCreateTensorDescriptor(&b);
    cnnlCreateTensorDescriptor(&c);
    setMatrixTensorEx(a, info.a_matrix, a_desc->dtype);
    setMatrixTensorEx(b, info.b_matrix, b_desc->dtype);
    setMatrixTensorEx(c, info.c_matrix, c_desc->dtype);
    cnnlMatMulDescriptor_t op;
    cnnlMatMulAlgo_t algo;
    cnnlMatMulHeuristicResult_t algoResult;
    cnnlMatMulDescCreate(&op);
    cnnlMatMulAlgoCreate(&algo);
    cnnlCreateMatMulHeuristicResult(&algoResult);
    // Tell CNNL to honour the explicit strides set on the tensor descriptors.
    int32_t use_stride = true;
    cnnlSetMatMulDescAttr(
        op,
        CNNL_MATMUL_USE_STRIDE,
        &use_stride,
        sizeof(int32_t));
    // Ask CNNL for one heuristic algorithm choice; `count` receives how many
    // results were actually produced.
    int count = 0;
    use_cnnl(handle->cnnl_handle_pool,
             [&](cnnlHandle_t _handle) {
                 cnnlGetBatchMatMulAlgoHeuristic(
                     _handle,
                     op, a, b, c,
                     NULL, 1, &algoResult, &count);
             });
    // Fix: initialize to 0 so the value is well-defined even if the CNNL
    // query does not write it (it was previously read uninitialized).
    size_t workspace_size = 0;
    cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size);
    *desc_ptr = new Descriptor(
        dtype, info, workspace_size,
        new Opaque{
            op,
            algo,
            algoResult,
            a,
            b,
            c,
            handle->cnnl_handle_pool},
        handle->device, handle->device_id);
    return INFINIOP_STATUS_SUCCESS;
}
// Enqueue the batched matmul c = alpha * a @ b + beta * c via CNNL, then
// block on cnrtQueueSync so results are visible when this call returns.
// workspace/workspace_size must satisfy the size negotiated at create().
infiniopStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
// When MatmulInfo flagged the problem as transposed, the a and b buffers
// are swapped to match the descriptors built at create() time
// (presumably the C^T = B^T A^T reformulation — confirm in blas.h).
if (_info.is_transed) {
std::swap(a, b);
}
use_cnnl(_opaque->cnnl_handle_pool,
(cnrtQueue_t)stream,
[&](cnnlHandle_t handle) {
cnnlBatchMatMulBCast_v2(
handle,
_opaque->op,
_opaque->algo,
&alpha,
_opaque->a, a,
_opaque->b, b,
&beta,
_opaque->c, c,
workspace,
workspace_size);
});
cnrtQueueSync((cnrtQueue_t)stream);
return INFINIOP_STATUS_SUCCESS;
}
} // namespace matmul::bang
#ifndef __MATMUL_BANG_H__
#define __MATMUL_BANG_H__
#include "../matmul.h"
DESCRIPTOR(bang)
#endif // __MATMUL_BANG_H__
#include "matmul_cnnl.h"
#include "../../../devices/bang/common_bang.h"
#include "../../utils.h"
#include "matmul_cnnl_api.h"
// Legacy C-style constructor for the BANG matmul descriptor (pre-refactor
// API): derives the GEMM shape, builds CNNL tensor/matmul descriptors and
// picks an algorithm heuristically. Dtype validation happens later, in
// bangMatmul.
infiniopStatus_t bangCreateMatmulDescriptor(
infiniopBangHandle_t handle, infiniopMatmulBangDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
infiniopStatus_t status;
// `false` = not column-major in the old bool-based MatmulInfo API.
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, false);
if (status != INFINIOP_STATUS_SUCCESS) {
return status;
}
// CNNL tensor descriptors mirroring the BLAS matrix layouts.
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
cnnlCreateTensorDescriptor(&aDesc);
cnnlCreateTensorDescriptor(&bDesc);
cnnlCreateTensorDescriptor(&cDesc);
setMatrixTensorEx(aDesc, info.a_matrix, a_desc->dtype);
setMatrixTensorEx(bDesc, info.b_matrix, b_desc->dtype);
setMatrixTensorEx(cDesc, info.c_matrix, c_desc->dtype);
cnnlMatMulDescriptor_t opDesc;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
cnnlMatMulDescCreate(&opDesc);
cnnlMatMulAlgoCreate(&algo);
cnnlCreateMatMulHeuristicResult(&algoResult);
// Tell CNNL to honour the explicit strides on the tensor descriptors.
int32_t use_stride = true;
cnnlSetMatMulDescAttr(opDesc, CNNL_MATMUL_USE_STRIDE, &use_stride,
sizeof(int32_t));
// Request a single heuristic algorithm choice.
int count = 0;
use_cnnl(handle->cnnl_handle_pool, [&](cnnlHandle_t _handle) {
cnnlGetBatchMatMulAlgoHeuristic(_handle, opDesc, aDesc, bDesc, cDesc,
NULL, 1, &algoResult, &count);
});
size_t workspace_size;
cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size);
*desc_ptr = new InfiniopMatmulBangDescriptor{handle->device,
handle->device_id,
info,
c_desc->dtype,
handle->cnnl_handle_pool,
aDesc,
bDesc,
cDesc,
opDesc,
algo,
algoResult,
workspace_size};
return INFINIOP_STATUS_SUCCESS;
}
// Report the workspace size negotiated with CNNL at descriptor creation.
infiniopStatus_t bangGetMatmulWorkspaceSize(infiniopMatmulBangDescriptor_t desc,
size_t *size) {
*size = desc->workspace_size;
return INFINIOP_STATUS_SUCCESS;
}
// Tear down a BANG matmul descriptor: release the handle-pool reference,
// destroy every CNNL object created in bangCreateMatmulDescriptor, then
// free the descriptor itself.
infiniopStatus_t
bangDestroyMatmulDescriptor(infiniopMatmulBangDescriptor_t desc) {
// Drop this descriptor's share of the pool before deletion.
desc->cnnl_handle_pool = nullptr;
cnnlDestroyTensorDescriptor(desc->aDesc);
cnnlDestroyTensorDescriptor(desc->bDesc);
cnnlDestroyTensorDescriptor(desc->cDesc);
cnnlMatMulDescDestroy(desc->opDesc);
cnnlMatMulAlgoDestroy(desc->algo);
cnnlDestroyMatMulHeuristicResult(desc->algoResult);
delete desc;
return INFINIOP_STATUS_SUCCESS;
}
// Enqueue the CNNL batched matmul c = alpha * a @ b + beta * c on `stream`.
// Asynchronous: no queue synchronization here — bangMatmul performs the sync.
void bangMatmulCnnl(infiniopMatmulBangDescriptor_t desc, void *workspace, void *c,
float beta, void const *a, void const *b, float alpha,
void *stream) {
auto info = desc->info;
// When MatmulInfo flagged the problem as transposed, the a and b buffers
// are swapped to match the descriptors built at create time.
if (info.is_transed) {
std::swap(a, b);
}
use_cnnl(desc->cnnl_handle_pool, (cnrtQueue_t)stream, [&](cnnlHandle_t handle) {
cnnlBatchMatMulBCast_v2(handle, desc->opDesc, desc->algo, &alpha,
desc->aDesc, a, desc->bDesc, b, &beta,
desc->cDesc, c, workspace,
desc->workspace_size);
});
}
// Validate the dtype, run the CNNL batched matmul, and synchronize the
// queue so results are visible when this call returns.
// Supported dtypes: F16 and F32; anything else is rejected.
infiniopStatus_t bangMatmul(infiniopMatmulBangDescriptor_t desc,
                            void *workspace, size_t workspace_size, void *c,
                            void const *a, void const *b, float alpha,
                            float beta, void *stream) {
    switch (desc->dtype) {
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_F32:
        bangMatmulCnnl(desc, workspace, c, beta, a, b, alpha, stream);
        cnrtQueueSync((cnrtQueue_t)stream);
        return INFINIOP_STATUS_SUCCESS;
    default:
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
#ifndef __CNNL_MATMUL_H__
#define __CNNL_MATMUL_H__
#include "../../../devices/bang/common_bang.h"
#include "../blas.h"
#include "cnnl_extra.h"
// Legacy (pre-refactor) BANG matmul descriptor: plain aggregate holding
// the device identity, problem shape, dtype, and all CNNL objects created
// in bangCreateMatmulDescriptor.
struct InfiniopMatmulBangDescriptor {
infiniDevice_t device;
int device_id;
// Batched GEMM shape / layout information derived from the tensor descs.
MatmulInfo info;
infiniDtype_t dtype;
// Shared CNNL handle pool; reset to nullptr on destroy.
std::shared_ptr<Pool<cnnlHandle_t>> cnnl_handle_pool;
cnnlTensorDescriptor_t aDesc;
cnnlTensorDescriptor_t bDesc;
cnnlTensorDescriptor_t cDesc;
cnnlMatMulDescriptor_t opDesc;
cnnlMatMulAlgo_t algo;
// Heuristic algorithm query result kept alive for the descriptor's lifetime.
cnnlMatMulHeuristicResult_t algoResult;
// Workspace bytes required by the chosen algorithm.
size_t workspace_size;
};
// Describe `matrix` to CNNL as a strided 2-D or 3-D tensor.
// Ranks other than 2 and 3 are ignored and leave `desc` unmodified.
// (`trans` exists for signature compatibility; it is not consulted.)
inline void setMatrixTensorEx(cnnlTensorDescriptor_t desc,
                              const BlasMatrix &matrix, infiniDtype_t dtype,
                              bool trans = false) {
    std::vector<int> shape;
    std::vector<int> strides;
    switch (matrix.ndim) {
    case 3:
        // Batched case: leading dimension is the batch axis.
        shape = {static_cast<int>(matrix.batch),
                 static_cast<int>(matrix.rows),
                 static_cast<int>(matrix.cols)};
        strides = {static_cast<int>(matrix.stride),
                   static_cast<int>(matrix.row_stride),
                   static_cast<int>(matrix.col_stride)};
        break;
    case 2:
        shape = {static_cast<int>(matrix.rows),
                 static_cast<int>(matrix.cols)};
        strides = {static_cast<int>(matrix.row_stride),
                   static_cast<int>(matrix.col_stride)};
        break;
    default:
        return;
    }
    cnnlSetTensorDescriptorEx(desc, CNNL_LAYOUT_ARRAY,
                              cnnlDataTypeConvert(dtype), shape.size(),
                              shape.data(), strides.data());
}
#endif // __CNNL_MATMUL_H__
#ifndef __CNNL_MATMUL_API_H__
#define __CNNL_MATMUL_API_H__
#include "../../../devices/bang/bang_handle.h"
#include "infiniop/operator.h"
struct InfiniopMatmulBangDescriptor;
typedef struct InfiniopMatmulBangDescriptor *infiniopMatmulBangDescriptor_t;
infiniopStatus_t bangCreateMatmulDescriptor(
infiniopBangHandle_t handle, infiniopMatmulBangDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
infiniopStatus_t bangGetMatmulWorkspaceSize(infiniopMatmulBangDescriptor_t desc,
size_t *size);
infiniopStatus_t bangMatmul(infiniopMatmulBangDescriptor_t desc,
void *workspace, size_t workspace_size, void *c,
void const *a, void const *b, float alpha,
float beta, void *stream);
infiniopStatus_t
bangDestroyMatmulDescriptor(infiniopMatmulBangDescriptor_t desc);
#endif
#ifndef __BLAS_H__
#define __BLAS_H__
#include "../utils.h"
#include "infiniop/operator.h"
#include <algorithm>
#include <stdint.h>
typedef struct BlasMatrix {
namespace matmul {
struct BlasMatrix {
size_t ndim;
size_t batch;
ptrdiff_t stride;
......@@ -15,31 +14,31 @@ typedef struct BlasMatrix {
ptrdiff_t row_stride;
ptrdiff_t col_stride;
BlasMatrix() {}
BlasMatrix() = default;
BlasMatrix(infiniopTensorDescriptor_t layout, infiniopStatus_t *status) {
if (layout->ndim == 2) {
this->ndim = 2;
this->batch = 1;
this->stride = 0;
this->rows = layout->shape[0];
this->cols = layout->shape[1];
this->row_stride = layout->strides[0];
this->col_stride = layout->strides[1];
ndim = 2;
batch = 1;
stride = 0;
rows = layout->shape[0];
cols = layout->shape[1];
row_stride = layout->strides[0];
col_stride = layout->strides[1];
} else if (layout->ndim == 3) {
this->ndim = 3;
this->batch = layout->shape[0];
this->stride = this->batch == 1 ? 0 : layout->strides[0];
this->rows = layout->shape[1];
this->cols = layout->shape[2];
this->row_stride = layout->strides[1];
this->col_stride = layout->strides[2];
ndim = 3;
batch = layout->shape[0];
stride = batch == 1 ? 0 : layout->strides[0];
rows = layout->shape[1];
cols = layout->shape[2];
row_stride = layout->strides[1];
col_stride = layout->strides[2];
} else {
*status = INFINIOP_STATUS_BAD_TENSOR_SHAPE;
return;
}
if (this->row_stride != 1 && this->col_stride != 1) {
if (row_stride != 1 && col_stride != 1) {
*status = INFINIOP_STATUS_BAD_TENSOR_STRIDES;
return;
}
......@@ -48,7 +47,7 @@ typedef struct BlasMatrix {
}
bool match_batch(size_t _batch) const {
return this->batch == _batch || this->batch == 1;
return batch == _batch || batch == 1;
}
void transpose() {
......@@ -57,13 +56,14 @@ typedef struct BlasMatrix {
}
ptrdiff_t ld() const {
if (this->row_stride == 1) {
return this->col_stride;
} else {
return this->row_stride;
}
return row_stride == 1 ? col_stride : row_stride;
}
} BlasMatrix;
};
enum class MatrixLayout : char {
COL_MAJOR,
ROW_MAJOR,
};
struct MatmulInfo {
BlasMatrix a_matrix;
......@@ -74,7 +74,11 @@ struct MatmulInfo {
bool is_transed = false;
MatmulInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopStatus_t *status, bool col_major = true) {
MatmulInfo(infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopStatus_t *status,
MatrixLayout layout) {
a_matrix = BlasMatrix(a_desc, status);
if (*status != INFINIOP_STATUS_SUCCESS) {
return;
......@@ -99,7 +103,8 @@ struct MatmulInfo {
return;
}
if ((col_major && c_matrix.col_stride == 1) || (!col_major && c_matrix.row_stride == 1)) {
if ((layout == MatrixLayout::COL_MAJOR && c_matrix.col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix.row_stride == 1)) {
c_matrix.transpose();
b_matrix.transpose();
a_matrix.transpose();
......@@ -112,5 +117,6 @@ struct MatmulInfo {
k = a_matrix.cols;
}
};
} // namespace matmul
#endif // __BLAS_H__
#include "./matmul_cpu.h"
#include "matmul_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h"
#include <cmath>
#include "../../../devices/cpu/cpu_handle.h"
infiniopStatus_t cpuCreateMatmulDescriptor(
infiniopCpuHandle_t handle, infiniopMatmulCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
namespace matmul::cpu {
Descriptor::~Descriptor() = default;
infiniopStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
infiniDtype_t dtype = c_desc->dtype;
auto handle = reinterpret_cast<infiniopCpuHandle_t>(handle_);
auto dtype = c_desc->dtype;
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
}
infiniopStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINIOP_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new MatmulCpuDescriptor{INFINI_DEVICE_CPU, dtype, info};
return INFINIOP_STATUS_SUCCESS;
}
infiniopStatus_t cpuGetMatmulWorkspaceSize(infiniopMatmulCpuDescriptor_t desc,
size_t *size) {
*size = 0;
return INFINIOP_STATUS_SUCCESS;
}
infiniopStatus_t
cpuDestroyMatmulDescriptor(infiniopMatmulCpuDescriptor_t desc) {
delete desc;
*desc_ptr = new Descriptor(
dtype, info, 0,
nullptr,
handle->device, handle->device_id);
return INFINIOP_STATUS_SUCCESS;
}
template <typename Tdata>
infiniopStatus_t cpuCalculateMatmul(infiniopMatmulCpuDescriptor_t desc, void *c,
float beta, void const *a, void const *b,
float alpha) {
auto info = desc->info;
void calculate(
const MatmulInfo &info,
void *c,
float beta,
const void *a,
const void *b,
float alpha) {
if (info.is_transed) {
std::swap(a, b);
}
......@@ -52,8 +50,8 @@ infiniopStatus_t cpuCalculateMatmul(infiniopMatmulCpuDescriptor_t desc, void *c,
auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride;
float sum = 0;
for (size_t k_ = 0; k_ < info.k; ++k_) {
auto a_ = reinterpret_cast<Tdata const *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<Tdata const *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
if constexpr (std::is_same<Tdata, uint16_t>::value) {
sum += f16_to_f32(*a_) * f16_to_f32(*b_);
} else {
......@@ -72,17 +70,30 @@ infiniopStatus_t cpuCalculateMatmul(infiniopMatmulCpuDescriptor_t desc, void *c,
}
}
}
return INFINIOP_STATUS_SUCCESS;
}
infiniopStatus_t cpuMatmul(infiniopMatmulCpuDescriptor_t desc, void *workspace,
size_t workspace_size, void *c, void const *a,
void const *b, float alpha, float beta) {
if (desc->dtype == INFINI_DTYPE_F16) {
return cpuCalculateMatmul<uint16_t>(desc, c, beta, a, b, alpha);
}
if (desc->dtype == INFINI_DTYPE_F32) {
return cpuCalculateMatmul<float>(desc, c, beta, a, b, alpha);
// Type-dispatched CPU matmul entry point. Validates the dtype and runs the
// scalar reference kernel. `workspace`, `workspace_size` and `stream` are
// accepted for interface uniformity and are unused on CPU.
infiniopStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
// F16 values are carried as raw uint16_t bits; the kernel converts
// them with f16_to_f32 before accumulating.
cpu::calculate<uint16_t>(_info, c, beta, a, b, alpha);
return INFINIOP_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
cpu::calculate<float>(_info, c, beta, a, b, alpha);
return INFINIOP_STATUS_SUCCESS;
default:
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
}
// NOTE(review): unreachable — every switch path returns above.
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace matmul::cpu
#ifndef __INFINIOP_MATMUL_CPU_H__
#define __INFINIOP_MATMUL_CPU_H__
#ifndef __MATMUL_CPU_H__
#define __MATMUL_CPU_H__
#include "../blas.h"
#include "./matmul_cpu_api.h"
#include "../matmul.h"
typedef struct MatmulCpuDescriptor {
infiniDevice_t device;
infiniDtype_t dtype;
MatmulInfo info;
} MatmulCpuDescriptor;
DESCRIPTOR(cpu)
#endif // __INFINIOP_MATMUL_CPU_H__
#endif // __MATMUL_CPU_H__
#ifndef __INFINIOP_MATMUL_CPU_API_H__
#define __INFINIOP_MATMUL_CPU_API_H__
#include "../../../devices/cpu/cpu_handle.h"
#include "infiniop/operator.h"
struct MatmulCpuDescriptor;
typedef struct MatmulCpuDescriptor *infiniopMatmulCpuDescriptor_t;
infiniopStatus_t cpuCreateMatmulDescriptor(
infiniopCpuHandle_t handle, infiniopMatmulCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
infiniopStatus_t cpuGetMatmulWorkspaceSize(infiniopMatmulCpuDescriptor_t desc,
size_t *size);
infiniopStatus_t cpuMatmul(infiniopMatmulCpuDescriptor_t desc, void *workspace,
size_t workspace_size, void *c, void const *a,
void const *b, float alpha, float beta);
infiniopStatus_t cpuDestroyMatmulDescriptor(infiniopMatmulCpuDescriptor_t desc);
#endif // __INFINIOP_MATMUL_CPU_API_H__
#include "../../../devices/cuda/common_cuda.cuh"
#include "../../utils.h"
#include "./matmul_cuda.cuh"
#include "matmul_cuda.cuh"
infiniopStatus_t cudaCreateMatmulDescriptor(infiniopCudaHandle_t handle,
infiniopMatmulCudaDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
infiniDtype_t dtype = c_desc->dtype;
namespace matmul::cuda {
// CUDA-specific state: just a reference to the shared cuBLAS handle pool.
struct Descriptor::Opaque {
std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool;
};
// Frees the hardware-specific state (drops the pool reference).
Descriptor::~Descriptor() {
delete _opaque;
}
infiniopStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<infiniopCudaHandle_t>(handle_);
auto dtype = c_desc->dtype;
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
}
infiniopStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINIOP_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new InfiniopMatmulCudaDescriptor{
handle->device,
dtype,
handle->device_id,
info,
handle->cublas_handle_pool};
*desc_ptr = new Descriptor(
dtype, info, 0,
new Opaque{handle->cublas_handle_pool},
handle->device, handle->device_id);
return INFINIOP_STATUS_SUCCESS;
}
infiniopStatus_t cudaGetMatmulWorkspaceSize(infiniopMatmulCudaDescriptor_t desc, size_t *size) {
*size = 0;
return INFINIOP_STATUS_SUCCESS;
// Strided-batched GEMM via cuBLAS: c = alpha * a @ b + beta * c.
// Tdata selects the element type: `half` uses CUDA_R_16F storage with F32
// compute; any other Tdata uses CUDA_R_32F (TF32 fast path unless built
// with ENABLE_SUGON_CUDA_API). The launch is asynchronous on `stream`.
template <typename Tdata>
void calculate(
const MatmulInfo &info,
std::shared_ptr<Pool<cublasHandle_t>> &cublas_handle_pool,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
cudaStream_t stream) {
// Swap a/b when MatmulInfo flagged the problem as transposed, matching
// how the shape info was derived.
if (info.is_transed) {
std::swap(a, b);
}
cudaDataType a_type, b_type, c_type;
cublasComputeType_t compute_type;
if constexpr (std::is_same<Tdata, half>::value) {
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
} else {
a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
compute_type = CUBLAS_COMPUTE_32F;
#else
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
}
// row_stride == 1 means the matrix is stored row-contiguous from
// cuBLAS's column-major point of view, so no transpose op is needed.
auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
use_cublas(cublas_handle_pool,
stream,
[&](cublasHandle_t handle) {
cublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
});
}
infiniopStatus_t cudaDestroyMatmulDescriptor(infiniopMatmulCudaDescriptor_t desc) {
desc->cublas_handle_pool = nullptr;
delete desc;
return INFINIOP_STATUS_SUCCESS;
// Type-dispatched CUDA matmul entry point. `workspace`/`workspace_size`
// are unused (cuBLAS manages its own workspace here); the launch is
// enqueued asynchronously on `stream`.
infiniopStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        // Fix: dispatch with `half` so cuda::calculate's
        // is_same<Tdata, half> branch selects CUDA_R_16F. The previous
        // `uint16_t` instantiation fell through to the FP32 configuration
        // while the buffers hold FP16 data.
        cuda::calculate<half>(_info, _opaque->cublas_handle_pool, c, beta, a, b, alpha, (cudaStream_t)stream);
        return INFINIOP_STATUS_SUCCESS;
    case INFINI_DTYPE_F32:
        cuda::calculate<float>(_info, _opaque->cublas_handle_pool, c, beta, a, b, alpha, (cudaStream_t)stream);
        return INFINIOP_STATUS_SUCCESS;
    default:
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace matmul::cuda
#ifndef __INFINIOP_MATMUL_CUDA_H__
#define __INFINIOP_MATMUL_CUDA_H__
#ifndef __MATMUL_CUDA_CUH__
#define __MATMUL_CUDA_CUH__
#include "../../../devices/cuda/common_cuda.cuh"
#include "../blas.h"
#include "matmul_cuda_api.h"
#include <memory>
#include "../matmul.h"
typedef struct InfiniopMatmulCudaDescriptor {
infiniDevice_t device;
infiniDtype_t dtype;
int device_id;
MatmulInfo info;
std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool;
} InfiniopMatmulCudaDescriptor;
DESCRIPTOR(cuda)
#endif // __INFINIOP_MATMUL_CUDA_H__
#endif // __MATMUL_CUDA_CUH__
#ifndef __INFINIOP_MATMUL_CUDA_API_H__
#define __INFINIOP_MATMUL_CUDA_API_H__
#include "../../../devices/cuda/cuda_handle.h"
#include "infiniop/operator.h"
struct InfiniopMatmulCudaDescriptor;
typedef struct InfiniopMatmulCudaDescriptor *infiniopMatmulCudaDescriptor_t;
infiniopStatus_t cudaCreateMatmulDescriptor(infiniopCudaHandle_t handle,
infiniopMatmulCudaDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
infiniopStatus_t cudaGetMatmulWorkspaceSize(infiniopMatmulCudaDescriptor_t desc, size_t *size);
infiniopStatus_t cudaMatmul(infiniopMatmulCudaDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
void const *a,
void const *b,
float alpha,
float beta,
void *stream);
infiniopStatus_t cudaDestroyMatmulDescriptor(infiniopMatmulCudaDescriptor_t desc);
#endif // __INFINIOP_MATMUL_CUDA_API_H__
#include "../../utils.h"
#include "./matmul_cuda.cuh"
// Legacy (pre-refactor) cuBLAS strided-batched GEMM:
// c = alpha * a @ b + beta * c. `half` selects CUDA_R_16F storage with F32
// compute; other Tdata uses CUDA_R_32F (TF32 fast path unless built with
// ENABLE_SUGON_CUDA_API). Always returns INFINIOP_STATUS_SUCCESS.
template <typename Tdata>
infiniopStatus_t cudaMatmulCublas(infiniopMatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) {
auto info = desc->info;
// Swap a/b when MatmulInfo flagged the problem as transposed.
if (info.is_transed) {
std::swap(a, b);
}
cudaDataType a_type, b_type, c_type;
cublasComputeType_t compute_type;
if constexpr (std::is_same<Tdata, half>::value) {
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
} else {
a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
compute_type = CUBLAS_COMPUTE_32F;
#else
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
}
// row_stride == 1: stored row-contiguous, no transpose op needed.
auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
use_cublas(desc->cublas_handle_pool, desc->device_id, (cudaStream_t)stream,
[&](cublasHandle_t handle) { cublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); });
return INFINIOP_STATUS_SUCCESS;
}
// Legacy type-dispatched CUDA matmul entry point: routes F16 to the half
// instantiation and F32 to the float instantiation of cudaMatmulCublas;
// rejects all other dtypes.
infiniopStatus_t cudaMatmul(infiniopMatmulCudaDescriptor_t desc,
                            void *workspace,
                            size_t workspace_size,
                            void *c,
                            void const *a,
                            void const *b,
                            float alpha,
                            float beta,
                            void *stream) {
    switch (desc->dtype) {
    case INFINI_DTYPE_F16:
        return cudaMatmulCublas<half>(desc, c, beta, a, b, alpha, stream);
    case INFINI_DTYPE_F32:
        return cudaMatmulCublas<float>(desc, c, beta, a, b, alpha, stream);
    default:
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
#ifndef __MATMUL_H__
#define __MATMUL_H__
#include "blas.h"
#include "infiniop/operator.h"
/**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
*
* > - Date: 2025/02/20
* > - Author: <YdrMaster ydrml@hotmail.com>
*
* `DESCRIPTOR(NAMESPACE)` 宏定义了一个 `Descriptor` 类,
* 该类继承自 `InfiniopDescriptor`,并在内部定义了:
*
* - 描述任何硬件上的矩阵乘算子必要的信息
* - 参数数据类型;
* - 矩阵乘运算信息(BMNK);
* - 工作空间大小;
* - 特定硬件独有的描述信息保存在 `struct Opaque;` 中;
* - 私有的构造函数;
* - 公共接口
* - 析构函数;
* - 静态的工厂函数;
* - 矩阵乘计算函数;
*
* 这个宏必须写成一个宏,因为静态成员函数是不可继承的,但每个不同硬件上有不同的实现。
* 使用宏声明的不同硬件的矩阵乘描述符是不相关类型(并且它们也不必相关)。
*
* 这些描述符的声明对 operator.cc 可见,而 operator.cc 文件编译到 infiniop 库的硬件无关部分中,
* 因此,描述符声明不可以出现硬件相关的类型。
*
* 为了隐藏必须伴随算子描述符保存的硬件相关类型(例如,Ascend 上的 `executor`),
* 这里使用了 `Opaque` 声明。
*
* Opaque 表示“不透明的”,即其内部结构面向头文件的使用者隐藏。
* 这种设计来自 C++ 的一种设计模式,其经典形式称为 [PImpl](https://zh.cppreference.com/w/cpp/language/pimpl) 模式。
* 由于此处仅需要结构体的指针,因此结构体可以仅声明不定义,从而隐藏了结构体的成员。
* 通过将结构体定义为 private,保证了外部不可能操作这个指针,
* 因此这个设计可以由类型自由控制其成员的生命周期,
* 同时不必将成员的形式暴露给头文件的使用者,
* 这是一种安全的封装。
*
* 这个宏仅适用于矩阵乘,但这种模式很容易复制到其他算子,以简化和规范算子的声明。
*/
// DESCRIPTOR(NAMESPACE) declares the per-backend matmul Descriptor class
// inside namespace matmul::NAMESPACE — see the design comment above for
// why this is a macro (static factory functions are not inheritable) and
// for the PImpl role of `struct Opaque`.
#define DESCRIPTOR(NAMESPACE) \
\
namespace matmul::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
\
Descriptor( \
infiniDtype_t dtype, \
MatmulInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
workspace_size(workspace_size_) {} \
\
public: \
size_t workspace_size; \
\
~Descriptor(); \
\
static infiniopStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc); \
\
infiniopStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *c, \
float beta, \
const void *a, \
const void *b, \
float alpha, \
void *stream) const; \
}; \
}
#endif // __MATMUL_H__
#include "infiniop/ops/matmul.h"
#ifdef ENABLE_CPU_API
#include "cpu/matmul_cpu_api.h"
#include "cpu/matmul_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/matmul_cuda_api.h"
#include "cuda/matmul_cuda.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/matmul_cnnl_api.h"
#include "bang/matmul_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/matmul_aclnn_api.h"
#include "ascend/matmul_ascend.h"
#endif
__C infiniopStatus_t infiniopCreateMatmulDescriptor(
infiniopHandle_t handle, infiniopMatmulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopHandle_t handle,
infiniopMatmulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return matmul::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<matmul::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
b_desc)
switch (handle->device) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuCreateMatmulDescriptor(
(infiniopCpuHandle_t)handle,
(infiniopMatmulCpuDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return cudaCreateMatmulDescriptor(
(infiniopCudaHandle_t)handle,
(infiniopMatmulCudaDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
}
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangCreateMatmulDescriptor(
(infiniopBangHandle_t)handle,
(infiniopMatmulBangDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
}
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND: {
return aclnnCreateMatmulDescriptor((infiniopAscendHandle_t)handle,
(MatmulAclnnDescriptor_t *)desc_ptr,
c_desc, a_desc, b_desc, 1);
}
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CREATE
}
__C infiniopStatus_t
infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t desc, size_t *size) {
switch (desc->device) {
infiniopGetMatmulWorkspaceSize(
infiniopMatmulDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const matmul::NAMESPACE::Descriptor *>(desc)->workspace_size; \
return INFINIOP_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuGetMatmulWorkspaceSize((infiniopMatmulCpuDescriptor_t)desc,
size);
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return cudaGetMatmulWorkspaceSize((infiniopMatmulCudaDescriptor_t)desc,
size);
}
GET(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangGetMatmulWorkspaceSize((infiniopMatmulBangDescriptor_t)desc,
size);
}
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND: {
return aclnnGetMatmulWorkspaceSize((MatmulAclnnDescriptor_t)desc, size);
}
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef GET
}
__C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc,
void *workspace, size_t workspace_size,
void *c, void const *a, void const *b,
float alpha, float beta, void *stream) {
switch (desc->device) {
__C infiniopStatus_t infiniopMatmul(
infiniopMatmulDescriptor_t desc,
void *workspace, size_t workspace_size,
void *c,
const void *a,
const void *b,
float alpha,
float beta,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const matmul::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, \
c, beta, \
a, b, alpha, \
stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuMatmul((infiniopMatmulCpuDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta);
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA:
return cudaMatmul((infiniopMatmulCudaDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta, stream);
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangMatmul((infiniopMatmulBangDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta, stream);
}
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND:
return aclnnMatmul((MatmulAclnnDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta, stream);
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CALCULATE
}
__C infiniopStatus_t
infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
    // Dispatch destruction to the device-specific Descriptor subtype.
    // Each enabled backend contributes one case; the macro deletes the
    // descriptor through its concrete type so the proper destructor runs.
#define DELETE(CASE, NAMESPACE)                                               \
    case CASE:                                                                \
        delete reinterpret_cast<const matmul::NAMESPACE::Descriptor *>(desc); \
        return INFINIOP_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
        DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
    default:
        // Device compiled out or unknown: nothing we can safely delete.
        return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
......@@ -39,7 +39,7 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handl
};
__C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDescriptor_t desc, size_t *size) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuGetRandomSampleWorkspaceSize((RandomSampleCpuDescriptor_t)desc, size);
......@@ -79,13 +79,13 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *result,
void const *probs,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
void *stream) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuRandomSample((RandomSampleCpuDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
......@@ -118,7 +118,7 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc,
}
__C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleDescriptor_t desc) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyRandomSampleDescriptor((RandomSampleCpuDescriptor_t)desc);
......
......@@ -43,8 +43,8 @@ __C infiniopStatus_t infiniopCreateRearrangeDescriptor(
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void *dst, void const *src, void *stream) {
switch (desc->device) {
__C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void *dst, const void *src, void *stream) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuRearrange((RearrangeCpuDescriptor_t)desc, dst, src, stream);
......@@ -83,7 +83,7 @@ __C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void
}
__C infiniopStatus_t infiniopDestroyRearrangeDescriptor(infiniopRearrangeDescriptor_t desc) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyRearrangeDescriptor((RearrangeCpuDescriptor_t)desc);
......
......@@ -47,7 +47,7 @@ __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
}
__C infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuGetRMSNormWorkspaceSize((RMSNormCpuDescriptor_t)desc, size);
......@@ -84,8 +84,8 @@ __C infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t
}
__C infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size,
void *y, void const *x, void const *w, void *stream) {
switch (desc->device) {
void *y, const void *x, const void *w, void *stream) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuRMSNorm((RMSNormCpuDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
......@@ -127,7 +127,7 @@ __C infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *wor
}
__C infiniopStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyRMSNormDescriptor((RMSNormCpuDescriptor_t)desc);
......
......@@ -54,7 +54,7 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(
__C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
size_t *size) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuGetRoPEWorkspaceSize((RoPECpuDescriptor_t)desc, size);
......@@ -91,10 +91,10 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
__C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
void *workspace, size_t workspace_size,
void *t, void const *pos_ids,
void const *sin_table, void const *cos_table,
void *t, const void *pos_ids,
const void *sin_table, const void *cos_table,
void *stream) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuRoPE((RoPECpuDescriptor_t)desc, workspace, workspace_size, t,
......@@ -138,7 +138,7 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
__C infiniopStatus_t
infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyRoPEDescriptor((RoPECpuDescriptor_t)desc);
......
......@@ -46,9 +46,9 @@ __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(
};
__C infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, void *c,
void const *a, void const *b,
const void *a, const void *b,
void *stream) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuSwiGLU((SwiGLUCpuDescriptor_t)desc, c, a, b, stream);
......@@ -80,7 +80,7 @@ __C infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, void *c,
__C infiniopStatus_t
infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
switch (desc->device) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroySwiGLUDescriptor((SwiGLUCpuDescriptor_t)desc);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment