Unverified Commit 8d09630a authored by gongchensu's avatar gongchensu Committed by GitHub
Browse files

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
import ninetoothed
from . import swiglu
import infiniop.ninetoothed.build
def build():
MAX_NDIM = 5
ndim_values = range(1, MAX_NDIM + 1)
dtype_values = (
ninetoothed.float16,
ninetoothed.bfloat16,
ninetoothed.float32,
)
constexpr_param_grid = {
"ndim": ndim_values,
"dtype": dtype_values,
"block_size": (1024,),
}
infiniop.ninetoothed.build.build(
swiglu.premake,
constexpr_param_grid,
caller="cuda",
op_name="swiglu",
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
)
#ifndef SWIGLU_H
#define SWIGLU_H
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/swiglu.h"
#include "../../../ninetoothed/utils.h"
namespace op::swiglu::ninetoothed {
class Descriptor final : public InfiniopDescriptor {
public:
Descriptor(
infiniopHandle_t handle,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id},
out_shape_{out_desc->shape()},
out_strides_{out_desc->strides()},
up_shape_{input_desc_vec[0]->shape()},
up_strides_{input_desc_vec[0]->strides()},
gate_shape_{input_desc_vec[1]->shape()},
gate_strides_{input_desc_vec[1]->strides()},
dtype_{out_desc->dtype()} {}
~Descriptor() = default;
size_t workspaceSize() const {
return 0;
}
static infiniStatus_t create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
*desc_ptr = new Descriptor(handle, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)};
auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)};
auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)};
if (launch_swiglu(stream,
out_nt,
up_nt,
gate_nt,
out_shape_.size(),
dtype_,
1024)) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
return INFINI_STATUS_SUCCESS;
}
private:
using Size = ::ninetoothed::Tensor<>::Size;
using Stride = ::ninetoothed::Tensor<>::Stride;
std::vector<Size> out_shape_;
std::vector<Stride> out_strides_;
std::vector<Size> up_shape_;
std::vector<Stride> up_strides_;
std::vector<Size> gate_shape_;
std::vector<Stride> gate_strides_;
infiniDtype_t dtype_;
};
} // namespace op::swiglu::ninetoothed
#endif // SWIGLU_H
import functools
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ntops.kernels.element_wise import arrangement
def application(output, up, gate):
output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841
def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)
tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)
return arrangement_, application, tensors
......@@ -5,15 +5,23 @@
#ifdef ENABLE_CPU_API
#include "cpu/swiglu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#if defined(ENABLE_NINETOOTHED)
#include "ninetoothed/swiglu.h"
#else
#include "nvidia/swiglu_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
#ifdef ENABLE_METAX_API
#if defined(ENABLE_NINETOOTHED)
#include "ninetoothed/swiglu.h"
#else
#include "metax/swiglu_metax.h"
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/swiglu_bang.h"
#endif
......@@ -46,11 +54,22 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -61,8 +80,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_METAX, ninetoothed);
#else
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
......@@ -92,11 +115,22 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -107,8 +141,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_METAX, ninetoothed);
#else
GET(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
......@@ -145,11 +183,22 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -160,8 +209,12 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
......@@ -193,11 +246,22 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -208,8 +272,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_METAX, ninetoothed);
#else
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/tanh_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/tanh_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
......@@ -40,6 +40,10 @@ __C infiniStatus_t infiniopCreateTanhDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax);
// #endif
......@@ -71,6 +75,10 @@ __C infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, s
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax);
// #endif
......@@ -109,6 +117,10 @@ __C infiniStatus_t infiniopTanh(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax);
// #endif
......@@ -142,6 +154,10 @@ infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// DELETE(INFINI_DEVICE_METAX, metax);
// #endif
......
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/topkrouter_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/topkrouter_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -38,6 +38,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -67,6 +70,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -99,6 +105,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -128,6 +137,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
}
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/topksoftmax_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/topksoftmax_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -33,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -60,6 +63,9 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -92,6 +98,9 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -119,6 +128,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
}
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/zeros_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/zeros_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -37,6 +37,9 @@ __C infiniStatus_t infiniopCreateZerosDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
......@@ -73,6 +76,9 @@ __C infiniStatus_t infiniopGetZerosWorkspaceSize(infiniopZerosDescriptor_t desc,
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -114,6 +120,9 @@ __C infiniStatus_t infiniopZeros(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -149,6 +158,9 @@ infiniopDestroyZerosDescriptor(infiniopZerosDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......
......@@ -50,7 +50,7 @@ __mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, in
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -81,7 +81,7 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -120,7 +120,7 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -165,7 +165,7 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -235,7 +235,7 @@ __mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, in
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -264,7 +264,7 @@ __mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_eleme
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......
......@@ -23,6 +23,7 @@ void printUsage() {
<< " qy" << std::endl
<< " kunlun" << std::endl
<< " hygon" << std::endl
<< " ali" << std::endl
<< std::endl;
exit(EXIT_FAILURE);
}
......@@ -55,6 +56,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
else PARSE_DEVICE("--qy", INFINI_DEVICE_QY)
else PARSE_DEVICE("--kunlun", INFINI_DEVICE_KUNLUN)
else PARSE_DEVICE("--hygon", INFINI_DEVICE_HYGON)
else PARSE_DEVICE("--ali", INFINI_DEVICE_ALI)
else {
printUsage();
}
......
......@@ -150,5 +150,35 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::ascend
#undef CHECK_ACLRT
......@@ -142,4 +142,34 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
CHECK_BANGRT(cnrtFree(ptr));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::bang
......@@ -116,4 +116,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::cpu
......@@ -4,6 +4,14 @@
#define CHECK_CUDART(RT_API) CHECK_INTERNAL(RT_API, cudaSuccess)
#define RUN_CUDART(RT_API) \
do { \
auto api_result_ = (RT_API); \
if (api_result_ != (cudaSuccess)) { \
{ return INFINI_STATUS_INTERNAL_ERROR; } \
} \
} while (0)
// 根据宏定义选择命名空间并实现
#if defined(ENABLE_NVIDIA_API)
namespace infinirt::cuda {
......@@ -13,6 +21,8 @@ namespace infinirt::iluvatar {
namespace infinirt::qy {
#elif defined(ENABLE_HYGON_API)
namespace infinirt::hygon {
#elif defined(ENABLE_ALI_API)
namespace infinirt::ali {
#else
namespace infinirt::cuda { // 默认回退
#endif
......@@ -40,7 +50,7 @@ infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
}
infiniStatus_t streamDestroy(infinirtStream_t stream) {
CHECK_CUDART(cudaStreamDestroy((cudaStream_t)stream));
RUN_CUDART(cudaStreamDestroy((cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
......@@ -105,7 +115,7 @@ infiniStatus_t eventSynchronize(infinirtEvent_t event) {
}
infiniStatus_t eventDestroy(infinirtEvent_t event) {
CHECK_CUDART(cudaEventDestroy((cudaEvent_t)event));
RUN_CUDART(cudaEventDestroy((cudaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
......@@ -125,12 +135,12 @@ infiniStatus_t mallocHost(void **p_ptr, size_t size) {
}
infiniStatus_t freeDevice(void *ptr) {
CHECK_CUDART(cudaFree(ptr));
RUN_CUDART(cudaFree(ptr));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t freeHost(void *ptr) {
CHECK_CUDART(cudaFreeHost(ptr));
RUN_CUDART(cudaFreeHost(ptr));
return INFINI_STATUS_SUCCESS;
}
......@@ -165,7 +175,56 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
}
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
CHECK_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
RUN_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
cudaStreamCaptureMode graph_mode;
if (mode == INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) {
graph_mode = cudaStreamCaptureModeGlobal;
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL) {
graph_mode = cudaStreamCaptureModeThreadLocal;
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_RELAXED) {
graph_mode = cudaStreamCaptureModeRelaxed;
} else {
return INFINI_STATUS_BAD_PARAM;
}
CHECK_CUDART(cudaStreamBeginCapture((cudaStream_t)stream, graph_mode));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
cudaGraph_t graph;
CHECK_CUDART(cudaStreamEndCapture((cudaStream_t)stream, &graph));
*graph_ptr = graph;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
RUN_CUDART(cudaGraphDestroy((cudaGraph_t)graph));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
CHECK_CUDART(cudaGraphInstantiate((cudaGraphExec_t *)graph_exec_ptr, (cudaGraph_t)graph, (cudaGraphNode_t *)node_ptr, log_buffer, buffer_size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
RUN_CUDART(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
CHECK_CUDART(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
}
......@@ -38,4 +38,13 @@ INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::hygon
// ALI namespace
namespace infinirt::ali {
#ifdef ENABLE_ALI_API
INFINIRT_DEVICE_API_IMPL
#else
INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::ali
#endif // __INFINIRT_CUDA_H__
......@@ -81,6 +81,9 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_
case INFINI_DEVICE_HYGON: \
_status = infinirt::hygon::API PARAMS; \
break; \
case INFINI_DEVICE_ALI: \
_status = infinirt::ali::API PARAMS; \
break; \
default: \
_status = INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
} \
......@@ -192,3 +195,32 @@ __C infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream
__C infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream) {
INFINIRT_CALL_DEVICE_API(freeAsync, (ptr, stream));
}
__C infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
INFINIRT_CALL_DEVICE_API(streamBeginCapture, (stream, mode));
}
__C infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
INFINIRT_CALL_DEVICE_API(streamEndCapture, (stream, graph_ptr));
}
__C infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph) {
INFINIRT_CALL_DEVICE_API(graphDestroy, (graph));
}
__C infiniStatus_t infinirtGraphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
INFINIRT_CALL_DEVICE_API(graphInstantiate, (graph_exec_ptr, graph, node_ptr, log_buffer, buffer_size));
}
__C infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec) {
INFINIRT_CALL_DEVICE_API(graphExecDestroy, (graph_exec));
}
__C infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
INFINIRT_CALL_DEVICE_API(graphLuanch, (graph_exec, stream));
}
......@@ -30,7 +30,19 @@
INLINE infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream) IMPL; \
\
INLINE infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) IMPL; \
INLINE infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) IMPL;
INLINE infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) IMPL; \
\
INLINE infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) IMPL; \
INLINE infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) IMPL; \
INLINE infiniStatus_t graphDestroy(infinirtGraph_t graph) IMPL; \
INLINE infiniStatus_t graphInstantiate( \
infinirtGraphExec_t *graph_exec_ptr, \
infinirtGraph_t graph, \
infinirtGraphNode_t *node_ptr, \
char *log_buffer, \
size_t buffer_size) IMPL; \
INLINE infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) IMPL; \
INLINE infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) IMPL;
#define INFINIRT_DEVICE_API_IMPL INFINIRT_DEVICE_API(, , )
#define INFINIRT_DEVICE_API_NOOP INFINIRT_DEVICE_API( \
......
......@@ -153,4 +153,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::kunlun
......@@ -152,4 +152,59 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
CHECK_MACART(hcFreeAsync(ptr, (hcStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
hcStreamCaptureMode graph_mode;
if (mode == INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) {
graph_mode = hcStreamCaptureModeGlobal;
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL) {
graph_mode = hcStreamCaptureModeThreadLocal;
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_RELAXED) {
graph_mode = hcStreamCaptureModeRelaxed;
} else {
return INFINI_STATUS_BAD_PARAM;
}
CHECK_MACART(hcStreamBeginCapture((hcStream_t)stream, graph_mode));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
hcGraph_t graph;
CHECK_MACART(hcStreamEndCapture((hcStream_t)stream, &graph));
*graph_ptr = graph;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
CHECK_MACART(hcGraphDestroy((hcGraph_t)graph));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
CHECK_MACART(hcGraphInstantiate(
(hcGraphExec_t *)graph_exec_ptr,
(hcGraph_t)graph,
(hcGraphNode_t *)node_ptr,
log_buffer,
buffer_size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
CHECK_MACART(hcGraphExecDestroy((hcGraphExec_t)graph_exec));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
CHECK_MACART(hcGraphLaunch((hcGraphExec_t)graph_exec, (hcStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
} // namespace infinirt::metax
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment