Commit f7137096 authored by YdrMaster's avatar YdrMaster
Browse files

issue/63/refactor: 重构 Matmul 所有实现,添加命名空间


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent 8e34901e
#include "infiniop/ops/matmul.h"
#ifdef ENABLE_CPU_API
#include "cpu/matmul_cpu_api.h"
#include "cpu/matmul_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/matmul_cuda_api.h"
#include "cuda/matmul_cuda.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/matmul_cnnl_api.h"
#include "bang/matmul_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/matmul_aclnn_api.h"
#include "ascend/matmul_ascend.h"
#endif
__C infiniopStatus_t infiniopCreateMatmulDescriptor(
infiniopHandle_t handle, infiniopMatmulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopHandle_t handle,
infiniopMatmulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
#define CREATE(CASE, HANDLE, NAMESPACE) \
case CASE: \
return matmul::NAMESPACE::Descriptor::create( \
reinterpret_cast<HANDLE>(handle), \
reinterpret_cast<matmul::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
b_desc)
switch (handle->device) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuCreateMatmulDescriptor(
(infiniopCpuHandle_t)handle,
(infiniopMatmulCpuDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
CREATE(INFINI_DEVICE_CPU, infiniopCpuHandle_t, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return cudaCreateMatmulDescriptor(
(infiniopCudaHandle_t)handle,
(infiniopMatmulCudaDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
}
CREATE(INFINI_DEVICE_NVIDIA, infiniopCudaHandle_t, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangCreateMatmulDescriptor(
(infiniopBangHandle_t)handle,
(infiniopMatmulBangDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
}
CREATE(INFINI_DEVICE_CAMBRICON, infiniopBangHandle_t, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND: {
return aclnnCreateMatmulDescriptor((infiniopAscendHandle_t)handle,
(MatmulAclnnDescriptor_t *)desc_ptr,
c_desc, a_desc, b_desc, 1);
}
CREATE(INFINI_DEVICE_ASCEND, infiniopAscendHandle_t, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CREATE
}
__C infiniopStatus_t
infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t desc, size_t *size) {
infiniopGetMatmulWorkspaceSize(
infiniopMatmulDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<matmul::NAMESPACE::Descriptor const *>(desc)->workspace_size; \
return INFINIOP_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuGetMatmulWorkspaceSize((infiniopMatmulCpuDescriptor_t)desc,
size);
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return cudaGetMatmulWorkspaceSize((infiniopMatmulCudaDescriptor_t)desc,
size);
}
GET(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangGetMatmulWorkspaceSize((infiniopMatmulBangDescriptor_t)desc,
size);
}
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND: {
return aclnnGetMatmulWorkspaceSize((MatmulAclnnDescriptor_t)desc, size);
}
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef GET
}
__C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc,
void *workspace, size_t workspace_size,
void *c, void const *a, void const *b,
float alpha, float beta, void *stream) {
__C infiniopStatus_t infiniopMatmul(
infiniopMatmulDescriptor_t desc,
void *workspace, size_t workspace_size,
void *c,
void const *a,
void const *b,
float alpha,
float beta,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<matmul::NAMESPACE::Descriptor const *>(desc) \
->calculate(workspace, workspace_size, \
c, beta, \
a, b, alpha, \
stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuMatmul((infiniopMatmulCpuDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta);
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA:
return cudaMatmul((infiniopMatmulCudaDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta, stream);
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangMatmul((infiniopMatmulBangDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta, stream);
}
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND:
return aclnnMatmul((MatmulAclnnDescriptor_t)desc, workspace,
workspace_size, c, a, b, alpha, beta, stream);
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CALCULATE
}
__C infiniopStatus_t
infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<matmul::NAMESPACE::Descriptor const *>(desc); \
return INFINIOP_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
case INFINI_DEVICE_CPU:
return cpuDestroyMatmulDescriptor((infiniopMatmulCpuDescriptor_t)desc);
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
case INFINI_DEVICE_NVIDIA: {
return cudaDestroyMatmulDescriptor(
(infiniopMatmulCudaDescriptor_t)desc);
}
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_API
case INFINI_DEVICE_CAMBRICON: {
return bangDestroyMatmulDescriptor((infiniopMatmulBangDescriptor_t)desc);
}
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
case INFINI_DEVICE_ASCEND: {
return aclnnDestroyMatmulDescriptor((MatmulAclnnDescriptor_t)desc);
}
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef DELETE
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment