"vscode:/vscode.git/clone" did not exist on "9a502a5b14b4a6160103c1f2c64331772878d86a"
Commit 2f2a74b6 authored by Zimin Li's avatar Zimin Li
Browse files

Merge remote-tracking branch 'upstream/main'

parents 1d95ddf3 70806eed
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "../../tensor.h" #include "../../tensor.h"
#include <algorithm> #include <algorithm>
namespace op::matmul { namespace op::gemm {
struct BlasMatrix { struct BlasMatrix {
size_t ndim; size_t ndim;
...@@ -120,6 +120,6 @@ struct MatmulInfo { ...@@ -120,6 +120,6 @@ struct MatmulInfo {
} }
}; };
} // namespace op::matmul } // namespace op::gemm
#endif // __BLAS_H__ #endif // __BLAS_H__
#include "matmul_cpu.h" #include "gemm_cpu.h"
#include "../../../devices/cpu/common_cpu.h" #include "../../../devices/cpu/common_cpu.h"
namespace op::matmul::cpu { namespace op::gemm::cpu {
Descriptor::~Descriptor() = default; Descriptor::~Descriptor() = default;
...@@ -95,4 +95,4 @@ infiniStatus_t Descriptor::calculate( ...@@ -95,4 +95,4 @@ infiniStatus_t Descriptor::calculate(
} }
} }
} // namespace op::matmul::cpu } // namespace op::gemm::cpu
// CPU backend header for the GEMM operator.
// DESCRIPTOR(cpu) expands (via ../gemm.h) to the declaration of
// op::gemm::cpu::Descriptor.
#ifndef __GEMM_CPU_H__
#define __GEMM_CPU_H__

#include "../gemm.h"

DESCRIPTOR(cpu)

#endif // __GEMM_CPU_H__
#include "../../../devices/cuda/cuda_handle.cuh" #include "../../../devices/cuda/cuda_handle.cuh"
#include "matmul_cuda.cuh" #include "gemm_cuda.cuh"
namespace op::matmul::cuda { namespace op::gemm::cuda {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal; std::shared_ptr<device::cuda::Handle::Internal> internal;
...@@ -109,4 +109,4 @@ infiniStatus_t Descriptor::calculate( ...@@ -109,4 +109,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::matmul::cuda } // namespace op::gemm::cuda
// CUDA backend header for the GEMM operator.
// DESCRIPTOR(cuda) expands (via ../gemm.h) to the declaration of
// op::gemm::cuda::Descriptor.
#ifndef __GEMM_CUDA_CUH__
#define __GEMM_CUDA_CUH__

#include "../gemm.h"

DESCRIPTOR(cuda)

#endif // __GEMM_CUDA_CUH__
#ifndef __MATMUL_H__ #ifndef __GEMM_H__
#define __MATMUL_H__ #define __GEMM_H__
#include "../../operator.h" #include "../../operator.h"
#include "blas.h" #include "blas.h"
...@@ -46,7 +46,7 @@ ...@@ -46,7 +46,7 @@
#define DESCRIPTOR(NAMESPACE) \ #define DESCRIPTOR(NAMESPACE) \
\ \
namespace op::matmul::NAMESPACE { \ namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \ class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \ struct Opaque; \
Opaque *_opaque; \ Opaque *_opaque; \
...@@ -90,4 +90,4 @@ ...@@ -90,4 +90,4 @@
}; \ }; \
} }
#endif // __MATMUL_H__ #endif // __GEMM_H__
#include "matmul_kunlun.h" #include "gemm_kunlun.h"
#include "../../../../utils.h" #include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_handle.h" #include "../../../devices/kunlun/kunlun_handle.h"
namespace op::matmul::kunlun { namespace op::gemm::kunlun {
typedef device::kunlun::Handle::Internal HandleInternal; typedef device::kunlun::Handle::Internal HandleInternal;
...@@ -103,12 +103,12 @@ infiniStatus_t Descriptor::calculate( ...@@ -103,12 +103,12 @@ infiniStatus_t Descriptor::calculate(
void *stream) const { void *stream) const {
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return op::matmul::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream); return op::gemm::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return op::matmul::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream); return op::gemm::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} }
} // namespace op::matmul::kunlun } // namespace op::gemm::kunlun
// Kunlun backend header for the GEMM operator.
// DESCRIPTOR(kunlun) expands (via ../gemm.h) to the declaration of
// op::gemm::kunlun::Descriptor.
#ifndef __GEMM_KUNLUN_H__
#define __GEMM_KUNLUN_H__

#include "../gemm.h"

DESCRIPTOR(kunlun)

#endif // __GEMM_KUNLUN_H__
#include "matmul_maca.h" #include "gemm_maca.h"
#include "../../../devices/maca/common_maca.h" #include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h" #include "../../../devices/maca/maca_handle.h"
namespace op::matmul::maca { namespace op::gemm::maca {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::maca::Handle::Internal> internal;
...@@ -106,4 +106,4 @@ infiniStatus_t Descriptor::calculate( ...@@ -106,4 +106,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::matmul::maca } // namespace op::gemm::maca
// MACA (Metax) backend header for the GEMM operator.
// DESCRIPTOR(maca) expands (via ../gemm.h) to the declaration of
// op::gemm::maca::Descriptor.
#ifndef __GEMM_MACA_H__
#define __GEMM_MACA_H__

#include "../gemm.h"

DESCRIPTOR(maca)

#endif // __GEMM_MACA_H__
#include "../../operator.h" #include "../../operator.h"
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/matmul.h" #include "infiniop/ops/gemm.h"
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/matmul_cpu.h" #include "cpu/gemm_cpu.h"
#endif #endif
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
#include "cuda/matmul_cuda.cuh" #include "cuda/gemm_cuda.cuh"
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
#include "bang/matmul_bang.h" #include "bang/gemm_bang.h"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/matmul_ascend.h" #include "ascend/gemm_ascend.h"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "maca/matmul_maca.h" #include "maca/gemm_maca.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/matmul_kunlun.h" #include "kunlun/gemm_kunlun.h"
#endif #endif
__C infiniStatus_t infiniopCreateMatmulDescriptor( __C infiniStatus_t infiniopCreateGemmDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopMatmulDescriptor_t *desc_ptr, infiniopGemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
#define CREATE(CASE, NAMESPACE) \ #define CREATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return op::matmul::NAMESPACE::Descriptor::create( \ return op::gemm::NAMESPACE::Descriptor::create( \
handle, \ handle, \
reinterpret_cast<op::matmul::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::gemm::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \ c_desc, \
a_desc, \ a_desc, \
b_desc) b_desc)
switch (handle->device) { switch (handle->device) {
...@@ -66,13 +66,13 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor( ...@@ -66,13 +66,13 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor(
} }
__C infiniStatus_t __C infiniStatus_t
infiniopGetMatmulWorkspaceSize( infiniopGetGemmWorkspaceSize(
infiniopMatmulDescriptor_t desc, infiniopGemmDescriptor_t desc,
size_t *size) { size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
*size = reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc)->workspace_size; \ *size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \
return INFINI_STATUS_SUCCESS return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
...@@ -103,8 +103,8 @@ infiniopGetMatmulWorkspaceSize( ...@@ -103,8 +103,8 @@ infiniopGetMatmulWorkspaceSize(
#undef GET #undef GET
} }
__C infiniStatus_t infiniopMatmul( __C infiniStatus_t infiniopGemm(
infiniopMatmulDescriptor_t desc, infiniopGemmDescriptor_t desc,
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *c, void *c,
const void *a, const void *a,
...@@ -113,12 +113,12 @@ __C infiniStatus_t infiniopMatmul( ...@@ -113,12 +113,12 @@ __C infiniStatus_t infiniopMatmul(
float beta, float beta,
void *stream) { void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, \ ->calculate(workspace, workspace_size, \
c, beta, \ c, beta, \
a, b, alpha, \ a, b, alpha, \
stream) stream)
switch (desc->device_type) { switch (desc->device_type) {
...@@ -150,11 +150,11 @@ __C infiniStatus_t infiniopMatmul( ...@@ -150,11 +150,11 @@ __C infiniStatus_t infiniopMatmul(
} }
__C infiniStatus_t __C infiniStatus_t
infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) { infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \ #define DELETE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
delete reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc); \ delete reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
......
// Ascend backend header for the matmul operator.
// DESCRIPTOR(ascend) expands (via ../matmul.h) to the declaration of
// op::matmul::ascend::Descriptor.
#ifndef __MATMUL_ASCEND_H__
#define __MATMUL_ASCEND_H__

#include "../matmul.h"

DESCRIPTOR(ascend)

#endif // __MATMUL_ASCEND_H__
// Cambricon BANG backend header for the matmul operator.
// DESCRIPTOR(bang) expands (via ../matmul.h) to the declaration of
// op::matmul::bang::Descriptor.
#ifndef __MATMUL_BANG_H__
#define __MATMUL_BANG_H__

#include "../matmul.h"

DESCRIPTOR(bang)

#endif // __MATMUL_BANG_H__
// CPU backend header for the matmul operator.
// DESCRIPTOR(cpu) expands (via ../matmul.h) to the declaration of
// op::matmul::cpu::Descriptor.
#ifndef __MATMUL_CPU_H__
#define __MATMUL_CPU_H__

#include "../matmul.h"

DESCRIPTOR(cpu)

#endif // __MATMUL_CPU_H__
// CUDA backend header for the matmul operator.
// DESCRIPTOR(cuda) expands (via ../matmul.h) to the declaration of
// op::matmul::cuda::Descriptor.
#ifndef __MATMUL_CUDA_CUH__
#define __MATMUL_CUDA_CUH__

#include "../matmul.h"

DESCRIPTOR(cuda)

#endif // __MATMUL_CUDA_CUH__
// Kunlun backend header for the matmul operator.
// DESCRIPTOR(kunlun) expands (via ../matmul.h) to the declaration of
// op::matmul::kunlun::Descriptor.
#ifndef __MATMUL_KUNLUN_H__
#define __MATMUL_KUNLUN_H__

#include "../matmul.h"

DESCRIPTOR(kunlun)

#endif // __MATMUL_KUNLUN_H__
// MACA (Metax) backend header for the matmul operator.
// DESCRIPTOR(maca) expands (via ../matmul.h) to the declaration of
// op::matmul::maca::Descriptor.
#ifndef __MATMUL_MACA_H__
#define __MATMUL_MACA_H__

#include "../matmul.h"

DESCRIPTOR(maca)

#endif // __MATMUL_MACA_H__
#include "random_sample_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../devices/cpu/cpu_handle.h"
#include "../../../tensor.h"
#include <algorithm>
namespace op::random_sample::cpu {
Descriptor::~Descriptor() = default;
// Validate the tensor descriptors and construct a CPU random-sample
// descriptor.
//
// Requirements enforced below:
//   - result_desc: 0-dim (scalar) tensor of any 8/16/32/64-bit integer dtype,
//     receives the sampled index.
//   - probs_desc: 1-D, contiguous (stride 1) tensor of f16/f32/f64
//     probabilities.
// Returns INFINI_STATUS_BAD_TENSOR_* on violation (via the CHECK_* macros,
// which return early), INFINI_STATUS_SUCCESS otherwise.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t result_desc,
    infiniopTensorDescriptor_t probs_desc) {
    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dt_i = result_desc->dtype(); // output index dtype
    auto dt_p = probs_desc->dtype();  // probability dtype
    CHECK_DTYPE(dt_i,
                INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64,
                INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
    CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    // Result is a scalar index.
    CHECK_API_OR(result_desc->ndim(), 0,
                 return INFINI_STATUS_BAD_TENSOR_SHAPE);
    // Probabilities form a single dense row.
    CHECK_API_OR(probs_desc->ndim(), 1,
                 return INFINI_STATUS_BAD_TENSOR_SHAPE);
    CHECK_API_OR(probs_desc->stride(0), 1,
                 return INFINI_STATUS_BAD_TENSOR_STRIDES);
    // The 0 / nullptr arguments appear to be the minimum workspace size and
    // an unused opaque pointer — TODO(review): confirm against the
    // Descriptor constructor declaration.
    *desc_ptr = new Descriptor(
        dt_i,
        dt_p,
        probs_desc->dim(0),
        0,
        nullptr,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Report the minimum external workspace required by calculate().
// (create() passes 0 for this field — the CPU path allocates its scratch
// internally with std::vector.)
size_t Descriptor::minWorkspaceSize() const {
    return _min_workspace_size;
}
// Maps a storage value type to the arithmetic type used for computation.
// By default a type computes in itself.
template <typename DT>
struct ComputeType {
    using type = DT;
};

// fp16 values are widened to float for arithmetic (see Scheme::get).
template <>
struct ComputeType<fp16_t> {
    using type = float;
};
// Sampling kernels for one (index type, value type) combination.
// Tval is the storage type of the probability buffer; arithmetic is done in
// ComputeType<Tval>::type (float when storage is fp16).
template <class Tidx, class Tval>
struct Scheme {
    using Tcompute = typename ComputeType<Tval>::type;

    // Load probs[i] and widen it to the compute type.
    static Tcompute get(void const *ptr, size_t i) {
        return utils::cast<Tcompute, Tval>(reinterpret_cast<Tval const *>(ptr)[i]);
    }

    // Greedy sampling: write the index of the largest probability to *result.
    static void argmax(
        void *result, void const *probs, size_t n) {
        auto idx = reinterpret_cast<Tidx *>(result);
        *idx = 0;
        if (n == 0) {
            // Empty distribution: keep the default index 0 instead of
            // reading probs[0] out of bounds.
            return;
        }
        auto max_val = get(probs, 0);
        // Start at 1: element 0 already seeded the running maximum.
        for (size_t i = 1; i < n; i++) {
            if (auto val = get(probs, i); val > max_val) {
                max_val = val;
                *idx = static_cast<Tidx>(i);
            }
        }
    }

    // Top-k / top-p (nucleus) sampling with temperature.
    // random_val in [0, 1) selects a point on the unnormalized cumulative
    // softmax curve; topk and topp jointly cap how far down the sorted
    // distribution the sample may land.
    static void random(
        void *result, void const *probs, size_t n,
        float random_val, float topp, int topk, float temperature) {
        struct KVPair {
            Tidx idx;
            Tcompute val;
            // Reversed comparison so std::sort orders descending by value.
            bool operator<(const KVPair &other) const {
                return val > other.val;
            }
        };
        auto idx = reinterpret_cast<Tidx *>(result);
        if (n == 0) {
            *idx = 0; // nothing to sample from; avoid pairs[0] OOB below
            return;
        }
        // Pair every probability with its original index and sort descending.
        std::vector<KVPair> pairs(n);
        for (size_t i = 0; i < n; i++) {
            pairs[i] = {static_cast<Tidx>(i), get(probs, i)};
        }
        std::sort(pairs.begin(), pairs.end());
        // Convert vals in place to a cumulative, max-shifted softmax
        // numerator: prefix sums of exp((v - max) / temperature). The
        // normalizing denominator cancels out of every comparison below, so
        // it is never computed. pairs[0] becomes exp(0) == 1.
        auto const max_val = pairs[0].val;
        pairs[0].val = 1;
        for (size_t i = 1; i < n; i++) {
            pairs[i].val = pairs[i - 1].val + std::exp((pairs[i].val - max_val) / temperature);
        }
        // Clamp topk into [1, n]; a non-positive topk is treated as
        // "no k limit". The previous std::min(size_t(topk), n) - 1
        // underflowed to SIZE_MAX for topk == 0 (out-of-bounds read).
        auto const k = topk > 0 ? std::min(static_cast<size_t>(topk), n) : n;
        // Sampling threshold: random_val scaled by the tighter of the
        // top-k mass and the top-p mass.
        auto const pk = pairs[k - 1].val,
                   pp = pairs[n - 1].val * topp,
                   plimit = random_val * std::min(pk, pp);
        // Fallback to the most probable token. This guards against
        // floating-point rounding (or random_val >= 1) leaving the loop
        // without a match, which previously left *idx unwritten.
        *idx = pairs[0].idx;
        // First cumulative value at or above the threshold wins.
        for (size_t i = 0; i < n; i++) {
            if (plimit <= pairs[i].val) {
                *idx = pairs[i].idx;
                break;
            }
        }
    }
};
// Choose between greedy and stochastic sampling for one type combination.
// Any parameter that degenerates the distribution to a deterministic pick
// (zero randomness, zero nucleus, k == 1, or zero temperature) routes to
// argmax; otherwise full top-k/top-p sampling runs.
template <class Tidx, class Tval>
void switch_f(
    size_t n,
    void *result, const void *probs,
    float random_val, float topp, int topk, float temperature) {
    const bool greedy = random_val == 0
                     || topp == 0
                     || topk == 1
                     || temperature == 0;
    if (greedy) {
        Scheme<Tidx, Tval>::argmax(result, probs, n);
    } else {
        Scheme<Tidx, Tval>::random(result, probs, n, random_val, topp, topk, temperature);
    }
}
// Dispatch on the probability dtype, instantiating switch_f with the
// matching storage type. dt_p was validated in Descriptor::create(), so
// any other value indicates a programming error.
template <class Tidx>
void switch_val(
    infiniDtype_t dt_p, size_t n,
    void *result, void const *probs,
    float random_val, float topp, int topk, float temperature) {
    if (dt_p == INFINI_DTYPE_F16) {
        switch_f<Tidx, fp16_t>(n, result, probs, random_val, topp, topk, temperature);
    } else if (dt_p == INFINI_DTYPE_F32) {
        switch_f<Tidx, float>(n, result, probs, random_val, topp, topk, temperature);
    } else if (dt_p == INFINI_DTYPE_F64) {
        switch_f<Tidx, double>(n, result, probs, random_val, topp, topk, temperature);
    } else {
        // unreachable: create() rejects every other probability dtype
        std::abort();
    }
}
// Dispatch on the result-index dtype; each case forwards to switch_val,
// which then dispatches on the probability dtype. dt_i was validated in
// Descriptor::create(), so the default branch is unreachable.
void switch_idx(
    infiniDtype_t dt_i, infiniDtype_t dt_p, size_t n,
    void *result, void const *probs,
    float random_val, float topp, int topk, float temperature) {
    switch (dt_i) {
    case INFINI_DTYPE_I8:
        switch_val<int8_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_I16:
        switch_val<int16_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_I32:
        switch_val<int32_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_I64:
        switch_val<int64_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_U8:
        switch_val<uint8_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_U16:
        switch_val<uint16_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_U32:
        switch_val<uint32_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    case INFINI_DTYPE_U64:
        switch_val<uint64_t>(dt_p, n, result, probs, random_val, topp, topk, temperature);
        break;
    default:
        // unreachable: create() rejects every other index dtype
        std::abort();
    }
}
// Run random sampling on the CPU.
// The CPU path is synchronous and allocates its own scratch, so workspace,
// workspace_size, and stream are unused; dispatch happens on the
// (index dtype, probability dtype) pair captured at create() time.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    void *stream) const {
    switch_idx(_dt_i, _dt_p, _n, result, probs, random_val, topp, topk, temperature);
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::cpu
// CPU backend header for the random-sample operator.
// DESCRIPTOR(cpu) expands (via ../random_sample.h) to the declaration of
// op::random_sample::cpu::Descriptor.
#ifndef __RANDOM_SAMPLE_CPU_H__
#define __RANDOM_SAMPLE_CPU_H__

#include "../random_sample.h"

DESCRIPTOR(cpu)

#endif // __RANDOM_SAMPLE_CPU_H__
...@@ -2,152 +2,111 @@ ...@@ -2,152 +2,111 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/random_sample.h" #include "infiniop/ops/random_sample.h"
__C infiniStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle, infiniopRandomSampleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs) { #ifdef ENABLE_CPU_API
switch (handle->device) { #include "cpu/random_sample_cpu.h"
#ifdef ENABLE_CPU
case DevCpu:
return cpuCreateRandomSampleDescriptor(handle, (RandomSampleCpuDescriptor_t *)desc_ptr, result, probs);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaCreateRandomSampleDescriptor((CudaHandle_t)handle, (RandomSampleCudaDescriptor_t *)desc_ptr, result, probs);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangCreateRandomSampleDescriptor((BangHandle_t)handle,
(RandomSampleBangDescriptor_t *)desc_ptr, result,
probs);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendCreateRandomSampleDescriptor((AscendHandle_t)handle,
(RandomSampleAscendDescriptor_t *)desc_ptr, result, probs);
}
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaCreateRandomSampleDescriptor((MacaHandle_t)handle,
(RandomSampleMacaDescriptor_t *)desc_ptr, result,
probs);
}
#endif #endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: __C infiniStatus_t infiniopCreateRandomSampleDescriptor(
return musaCreateRandomSampleDescriptor((MusaHandle_t)handle, (RandomSampleMusaDescriptor_t *)desc_ptr, result, probs); infiniopHandle_t handle,
infiniopRandomSampleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t result,
infiniopTensorDescriptor_t probs) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::random_sample::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::random_sample::NAMESPACE::Descriptor **>(desc_ptr), \
result, \
probs)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CREATE
}; };
__C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
infiniopRandomSampleDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
using Ptr = const op::random_sample::NAMESPACE::Descriptor *; \
*size = reinterpret_cast<Ptr>(desc)->minWorkspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuGetRandomSampleWorkspaceSize((RandomSampleCpuDescriptor_t)desc, size);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaGetRandomSampleWorkspaceSize((RandomSampleCudaDescriptor_t)desc, size);
}
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { default:
return bangGetRandomSampleWorkspaceSize((RandomSampleBangDescriptor_t)desc, size); return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
// return cnnlGetRandomSampleWorkspaceSize((RandomSampleCnnlDescriptor_t) desc, size);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendGetRandomSampleWorkspaceSize((RandomSampleAscendDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaGetRandomSampleWorkspaceSize((RandomSampleMacaDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetRandomSampleWorkspaceSize((RandomSampleMusaDescriptor_t)desc, size);
}
#endif
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef GET
} }
__C infiniStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc, __C infiniStatus_t infiniopRandomSample(
void *workspace, infiniopRandomSampleDescriptor_t desc,
size_t workspace_size, void *workspace,
void *result, size_t workspace_size,
const void *probs, void *result,
float random_val, const void *probs,
float topp, float random_val,
int topk, float topp,
float temperature, int topk,
void *stream) { float temperature,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::random_sample::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, \
result, probs, \
random_val, \
topp, topk, temperature, \
stream)
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu: #ifdef ENABLE_CPU_API
return cpuRandomSample((RandomSampleCpuDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream); CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaRandomSample((RandomSampleCudaDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangRandomSample((RandomSampleBangDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendRandomSample((RandomSampleAscendDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
}
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaRandomSample((RandomSampleMacaDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu:
return musaRandomSample((RandomSampleMusaDescriptor_t)desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CALCULATE
} }
__C infiniStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleDescriptor_t desc) { __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
infiniopRandomSampleDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::random_sample::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu: #ifdef ENABLE_CPU_API
return cpuDestroyRandomSampleDescriptor((RandomSampleCpuDescriptor_t)desc); DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaDestroyRandomSampleDescriptor((RandomSampleCudaDescriptor_t)desc);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangDestroyRandomSampleDescriptor((RandomSampleBangDescriptor_t)desc);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendDestroyRandomSampleDescriptor((RandomSampleAscendDescriptor_t)desc);
}
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaDestroyRandomSampleDescriptor((RandomSampleMacaDescriptor_t)desc);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu:
return musaDestroyRandomSampleDescriptor((RandomSampleMusaDescriptor_t)desc);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef DELETE
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment