Unverified Commit 0e1c5585 authored by zhangyue's avatar zhangyue Committed by GitHub
Browse files

Merge pull request #358 from InfiniTensor/issue/342

issue/342: 昆仑芯P800上random_sample算子
parents 19d60bf8 1cadb2a1
......@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
__builtin_memcpy(v, p, len * sizeof(T));
}
/**
* @brief Convert data type. All data is in local memory
* @param v: input value
* @return output value
*/
template <typename Tout, typename Tin>
__device__ inline Tout to(Tin v) {
if constexpr (std::is_same<Tin, half>::value) {
return __half2float(v);
} else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
return __bfloat162float(v);
} else {
return static_cast<Tout>(v);
}
}
/**
* @brief atomicAdd for kunlun xpu
* @param ptr: pointer to shared memory
......
This diff is collapsed.
#ifndef __RANDOM_SAMPLE_KUNLUN_H__
#define __RANDOM_SAMPLE_KUNLUN_H__
#include "../random_sample.h"
DESCRIPTOR(kunlun)
#endif // __RANDOM_SAMPLE_KUNLUN_H__
#include "random_sample_kunlun.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../info.h"
#include "kernel.h"
#include "xpu/kernel/xtdk_io.h"
template <typename Tval, typename Tidx>
void launchKernel(void *workspace,
void *result,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
int64_t n,
XPUStream stream) {
constexpr unsigned int cluster_num = 8;
constexpr unsigned int core_num = 64;
char *workspace_value = reinterpret_cast<char *>(workspace);
int topk_ = topk <= (int)n ? topk : (int)n;
bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f;
Tval *values = (Tval *)workspace_value;
xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE);
Tval *values_global = values + n;
char *workspace_sum = workspace_value + (n + cluster_num * core_num * topk_) * sizeof(Tval);
float *sum_global = (float *)workspace_sum;
char *workspace_index = workspace_sum + cluster_num * sizeof(float);
Tidx *indices = (Tidx *)workspace_index;
Tidx *indices_global = indices + n;
if (dosample){
randomSampleKernel<cluster_num, core_num, Tval, float, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs,
random_val,
topp,
n,
topk_,
temperature,
indices,
values,
indices_global,
values_global,
sum_global);
}
else{
argmaxKernel<Tval, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result, (Tval *)probs, n,
indices,
values,
indices_global,
values_global);
}
}
#define LAUNCH_KERNEL(Tval, Tidx) \
launchKernel<Tval, Tidx>(workspace, result, probs, random_val, topp, topk, temperature, n, reinterpret_cast<kunlunStream_t>(stream));
namespace op::random_sample::kunlun {
struct Descriptor::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto result = RandomSampleInfo::create(result_desc, probs_desc);
CHECK_RESULT(result);
auto info = result.take();
int cluster_num = 8;
int core_num = 64;
int n = probs_desc->numel();
size_t workspace_size = (n + cluster_num * core_num * n) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float);
*desc_ptr = new Descriptor(
info,
workspace_size,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
size_t Descriptor::minWorkspaceSize() const {
return _min_workspace_size;
}
infiniStatus_t
Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *result,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
void *stream) const {
if (workspace_size < _min_workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
int n = (int)_info.n;
if (_info.dt_i == INFINI_DTYPE_I32){
switch (_info.dt_p) {
case INFINI_DTYPE_F16:
LAUNCH_KERNEL(half, int32_t);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_BF16:
LAUNCH_KERNEL(bfloat16_t, int32_t);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
LAUNCH_KERNEL(float, int32_t);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
else if (_info.dt_i == INFINI_DTYPE_I64){
switch (_info.dt_p) {
case INFINI_DTYPE_F16:
LAUNCH_KERNEL(half, int64_t);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_BF16:
LAUNCH_KERNEL(bfloat16_t, int64_t);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
LAUNCH_KERNEL(float, int64_t);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::kunlun
......@@ -20,6 +20,9 @@
#ifdef ENABLE_MOORE_API
#include "moore/random_sample_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/random_sample_kunlun.h"
#endif
__C infiniStatus_t
infiniopCreateRandomSampleDescriptor(
......@@ -59,6 +62,9 @@ infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -101,6 +107,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -153,6 +162,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -192,6 +204,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment