Unverified Commit 85bc98ac authored by qinyiqun, committed by GitHub

ISSUE/628 Adapt to the QY C610 GPU: add build options, adapt existing operators, and add the operators required by bge-style models (#629)



* ISSUE/628 Adapt to the QY C610 GPU: add build options and adapt the existing operators. Add the operators required by bge-style models, including gelu, layer_norm, lp_norm (with L1/L2 norm support), relu, softmax, and tanh.

---------
Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Co-authored-by: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com>
parent 7c397dd2
#include "layer_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
#include "../info.h"
namespace op::layer_norm::cpu {
template <typename Tdata>
infiniStatus_t calculate_layer_norm(
const LayerNormInfo &info,
Tdata *output,
Tdata *input_standardization,
Tdata *input_std_deviation,
const Tdata *input,
const Tdata *weight,
const Tdata *bias) {
#pragma omp parallel for
for (int b = 0; b < (int)(info.input_shape[0] * info.input_shape[1]); b++) {
int b0 = b / (int)info.input_shape[1], b1 = b % (int)info.input_shape[1];
auto output_ptr = output + b0 * info.output_strides[0] + b1 * info.output_strides[1];
auto input_ptr = input + b0 * info.input_strides[0] + b1 * info.input_strides[1];
auto standard_ptr = input_standardization + b0 * info.input_standardization_strides[0] + b1 * info.input_standardization_strides[1];
auto std_ptr = input_std_deviation + b0 * info.input_std_deviation_strides[0] + b1 * info.input_std_deviation_strides[1];
float mean = op::common_cpu::reduce_op::sum(
input_ptr,
info.normalized_size,
info.input_strides[2])
/ info.input_shape[2];
float sum_sq = op::common_cpu::reduce_op::sumSquared(
input_ptr,
info.normalized_size,
info.input_strides[2]);
float var = sum_sq / (info.normalized_size) - mean * mean;
float std_deviation = std::sqrt(var + info.eps);
*std_ptr = utils::cast<Tdata>(std_deviation);
for (size_t d = 0; d < info.normalized_size; d++) {
float x_standard = (utils::cast<float>(*(input_ptr + d * info.input_strides[2])) - mean) / std_deviation;
*(standard_ptr + d * info.input_standardization_strides[2]) = utils::cast<Tdata>(x_standard);
*(output_ptr + d * info.output_strides[2]) = utils::cast<Tdata>(
x_standard * utils::cast<float>(*(weight + d * info.weight_strides[0])) + (info.bias_exist ? utils::cast<float>(*(bias + d * info.bias_strides[0])) : float(0)));
}
}
return INFINI_STATUS_SUCCESS;
}
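For reference, each row processed by the loop above is normalized as in standard layer norm over the last dimension (the CPU path walks the input as a 3-D tensor with D = normalized_size):

\[
\mu = \frac{1}{D}\sum_{d=1}^{D} x_d, \qquad
\sigma = \sqrt{\frac{1}{D}\sum_{d=1}^{D} x_d^{2} - \mu^{2} + \varepsilon}, \qquad
\hat{x}_d = \frac{x_d - \mu}{\sigma}, \qquad
y_d = \hat{x}_d\, w_d + b_d,
\]

with \(b_d = 0\) when no bias tensor is supplied. The standardized values \(\hat{x}_d\) and the per-row \(\sigma\) are also written out through input_standardization and input_std_deviation.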
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_standardization_desc,
infiniopTensorDescriptor_t input_std_deviation_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t bias_desc,
float eps) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
// --------------------- start: check data type and calculate workspace size ----------------------
auto dtype = input_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
size_t WorkSpaceSize = 0;
auto result = LayerNormInfo::createLayerNormInfo(
output_desc,
input_standardization_desc,
input_std_deviation_desc,
input_desc,
weight_desc,
bias_desc,
eps);
CHECK_RESULT(result);
const LayerNormInfo &info = result.take();
*desc_ptr = new Descriptor(
dtype, std::move(info), WorkSpaceSize,
nullptr,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_LAYER_NORM(TDATA) \
CHECK_STATUS(calculate_layer_norm<TDATA>(_info, \
(TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias))
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
void *input_standardization,
void *input_std_deviation,
const void *input,
const void *weight,
const void *bias,
void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) {
CALCULATE_LAYER_NORM(fp16_t);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
CALCULATE_LAYER_NORM(bf16_t);
} else if (_info.dtype == INFINI_DTYPE_F32) {
CALCULATE_LAYER_NORM(float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::layer_norm::cpu
#ifndef __LAYER_NORM_CPU_H__
#define __LAYER_NORM_CPU_H__
#include "../layer_norm.h"
DESCRIPTOR(cpu)
#endif // __LAYER_NORM_CPU_H__
#ifndef __LAYER_NORM_KERNEL_CUH__
#define __LAYER_NORM_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
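// This header provides three layer-norm kernel variants (see the backend
// launchers for the exact launch configurations):
// - layerNormKernel: one (b0, b1) row per block over arbitrary strides; the
//   MetaX backend launches it with grid = (shape[0], shape[1]).
// - blockLayernormKernel: one row per block reduced with cub::BlockReduce;
//   the NVIDIA backend selects it when dimsize > 1024.
// - warpLayernormKernel: blockDim = (32, 32), one warp per output row using
//   a shuffle-based reduction; selected when dimsize <= 1024.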
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ void layerNormKernel(
Tdata *output,
Tdata *input_standardization,
Tdata *input_std_deviation,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
float eps,
size_t normalized_size,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_standardization_strides,
const ptrdiff_t *input_std_deviation_strides,
const ptrdiff_t *input_strides,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
bool bias_exist) {
size_t b0 = blockIdx.x, b1 = blockIdx.y;
auto output_ptr = output + b0 * output_strides[0] + b1 * output_strides[1];
auto input_ptr = input + b0 * input_strides[0] + b1 * input_strides[1];
auto standard_ptr = input_standardization + b0 * input_standardization_strides[0] + b1 * input_standardization_strides[1];
auto std_ptr = input_std_deviation + b0 * input_std_deviation_strides[0] + b1 * input_std_deviation_strides[1];
Tcompute mean = op::common_cuda::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(
input_ptr,
normalized_size)
/ normalized_size;
Tcompute sum_squared = op::common_cuda::reduce_op::sumSquared<BLOCK_SIZE, Tdata, Tcompute>(
input_ptr,
normalized_size);
Tcompute var = sum_squared / normalized_size - mean * mean;
Tcompute std_deviation = sqrtf(var + Tcompute(eps));
*std_ptr = std_deviation;
for (size_t d = 0; d < normalized_size; d++) {
Tcompute x_standard = (Tcompute(input_ptr[d]) - mean) / std_deviation;
standard_ptr[d] = x_standard;
output_ptr[d] = x_standard * Tcompute(*(weight + d * weight_stride)) + (bias_exist ? Tcompute(*(bias + d * bias_stride)) : Tcompute(0));
}
}
template <typename T, int BLOCK_SIZE>
__device__ void blockLayernormKernel(T *output, T const *input, T const *weight, T const *bias, float eps, int dimsize,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
int ndim,
bool bias_exist) {
// Only handles normalization along the last axis (axis == -1).
int ind_i = 0; // input id
int ind_o = 0; // output id
int tid = blockIdx.x;
for (int j = ndim - 2; j >= 0; j--) {
ind_i += (tid % (int)shape[j]) * (int)input_strides[j];
ind_o += (tid % (int)shape[j]) * (int)output_strides[j];
tid = tid / (int)shape[j];
}
float mu_partial = op::common_cuda::reduce_op::sum<BLOCK_SIZE, T, float>(
input + ind_i,
dimsize)
/ dimsize;
__shared__ float mu;
if (threadIdx.x == 0) {
mu = mu_partial;
} // only threadIdx.x == 0 holds the block-wide reduction result
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__syncthreads();
float sigma2_partial = 0.0f;
for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE) {
sigma2_partial += (static_cast<float>(input[ind_i + id]) - mu) * (static_cast<float>(input[ind_i + id]) - mu);
}
__shared__ float sigma2;
float sigma2_block = BlockReduce(temp_storage).Reduce(sigma2_partial, cub::Sum());
if (threadIdx.x == 0) {
float sigma_tmp = sqrt(sigma2_block * __fdividef(1.0F, dimsize) + eps);
sigma2 = __fdividef(1.0F, sigma_tmp);
}
__syncthreads();
for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE) {
output[ind_o + id] = static_cast<T>(static_cast<float>(weight[id * weight_stride]) * (static_cast<float>(input[ind_i + id]) - mu) * sigma2 + (bias_exist ? static_cast<float>(bias[id * bias_stride]) : 0.0f));
}
}
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return a + b;
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <typename T, int BLOCK_SIZE_x, int BLOCK_SIZE_y>
__device__ void warpLayernormKernel(T *output, T const *input, T const *weight, T const *bias, float eps, int othersize, int dimsize,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
int ndim,
bool bias_exist) {
// Assumes dimsize <= 1024 (larger rows are dispatched to blockLayernormKernel).
int ind_i = 0; // input id
int ind_o = 0; // output id
int tid = blockIdx.x * blockDim.y + threadIdx.y;
if (tid < othersize) {
for (int j = ndim - 2; j >= 0; j--) {
ind_i += (tid % (int)shape[j]) * (int)input_strides[j];
ind_o += (tid % (int)shape[j]) * (int)output_strides[j];
tid = tid / (int)shape[j];
}
float mu_partial = 0.0f;
for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE_x) {
mu_partial += static_cast<float>(input[ind_i + id]);
}
mu_partial = WarpAllReduce<SumOp, float, BLOCK_SIZE_x>(mu_partial);
__shared__ float mu[BLOCK_SIZE_y];
if (threadIdx.x == 0) {
mu[threadIdx.y] = mu_partial * __fdividef(1.0F, dimsize);
} // lane 0 writes this row's mean computed by the warp-wide reduction
__syncthreads();
float sigma2_partial = 0.0f;
for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE_x) {
sigma2_partial += (static_cast<float>(input[ind_i + id]) - mu[threadIdx.y]) * (static_cast<float>(input[ind_i + id]) - mu[threadIdx.y]);
}
sigma2_partial = WarpAllReduce<SumOp, float, BLOCK_SIZE_x>(sigma2_partial);
__shared__ float sigma2[BLOCK_SIZE_y];
if (threadIdx.x == 0) {
float sigma_tmp = sqrt(sigma2_partial * __fdividef(1.0F, dimsize) + eps);
sigma2[threadIdx.y] = __fdividef(1.0F, sigma_tmp);
}
__syncthreads();
for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE_x) {
output[ind_o + id] = static_cast<T>(static_cast<float>(weight[id * weight_stride]) * (static_cast<float>(input[ind_i + id]) - mu[threadIdx.y]) * sigma2[threadIdx.y] + (bias_exist ? static_cast<float>(bias[id * bias_stride]) : 0.0f));
}
}
}
#endif // __LAYER_NORM_KERNEL_CUH__
#ifndef __LAYER_NORM_INFO_H__
#define __LAYER_NORM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
namespace op::layer_norm {
class LayerNormInfo {
private:
LayerNormInfo() = default;
public:
infiniDtype_t dtype;
size_t ndim;
std::vector<size_t> input_shape;
size_t normalized_size;
size_t othersize;
std::vector<ptrdiff_t> output_strides;
std::vector<ptrdiff_t> input_standardization_strides;
std::vector<ptrdiff_t> input_std_deviation_strides;
std::vector<ptrdiff_t> input_strides;
std::vector<ptrdiff_t> weight_strides;
std::vector<ptrdiff_t> bias_strides;
float eps;
bool bias_exist;
static utils::Result<LayerNormInfo> createLayerNormInfo(
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_standardization_desc,
infiniopTensorDescriptor_t input_std_deviation_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t bias_desc,
float eps) {
CHECK_SAME_SHAPE(
output_desc->shape(), input_desc->shape(), input_standardization_desc->shape());
size_t ndim = input_desc->ndim();
size_t normalized_size = input_desc->dim(ndim - 1);
size_t othersize = 1;
for (size_t i = 0; i < ndim - 1; i++) {
othersize *= input_desc->dim(i);
}
size_t feature_size = input_desc->dim(ndim - 1);
bool bias_exist = bias_desc != nullptr;
CHECK_OR_RETURN(
(!bias_exist) || (bias_desc->ndim() == 1 && bias_desc->dim(0) == feature_size),
INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(
(weight_desc->ndim() == 1) && (weight_desc->dim(0) == feature_size),
INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(
input_std_deviation_desc->ndim() == ndim - 1,
INFINI_STATUS_BAD_TENSOR_SHAPE);
for (size_t i = 0; i < ndim - 1; i++) {
CHECK_OR_RETURN(
input_std_deviation_desc->dim(i) == input_desc->dim(i),
INFINI_STATUS_BAD_TENSOR_SHAPE);
}
return utils::Result<LayerNormInfo>(LayerNormInfo{
output_desc->dtype(),
ndim,
input_desc->shape(),
normalized_size,
othersize,
output_desc->strides(),
input_standardization_desc->strides(),
input_std_deviation_desc->strides(),
input_desc->strides(),
weight_desc->strides(),
bias_exist ? bias_desc->strides() : std::vector<ptrdiff_t>(),
eps,
bias_exist});
}
};
} // namespace op::layer_norm
#endif // __LAYER_NORM_INFO_H__
#ifndef __LAYER_NORM_H__
#define __LAYER_NORM_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
namespace op::layer_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
LayerNormInfo _info; \
size_t _workspace_size; \
Descriptor( \
infiniDtype_t dtype, \
LayerNormInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t output_desc, \
infiniopTensorDescriptor_t input_standardization_desc, \
infiniopTensorDescriptor_t input_std_deviation_desc, \
infiniopTensorDescriptor_t input_desc, \
infiniopTensorDescriptor_t weight_desc, \
infiniopTensorDescriptor_t bias_desc, \
float eps); \
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *output, \
void *input_standardization, \
void *input_std_deviation, \
const void *input, \
const void *weight, \
const void *bias, \
void *stream) const; \
}; \
}
#endif
#ifndef __LAYER_NORM_METAX_H__
#define __LAYER_NORM_METAX_H__
#include "../layer_norm.h"
DESCRIPTOR(metax)
#endif // __LAYER_NORM_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "layer_norm_metax.h"
#include <hccub/block/block_reduce.cuh>
#include "../../../devices/metax/metax_kernel_common.h"
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "../info.h"
namespace op::layer_norm::metax {
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_METAX_KERNEL launchKernel(
Tdata *output,
Tdata *input_standardization,
Tdata *input_std_deviation,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
float eps,
size_t normalized_size,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_standardization_strides,
const ptrdiff_t *input_std_deviation_strides,
const ptrdiff_t *input_strides,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
bool bias_exist) {
layerNormKernel<BLOCK_SIZE, Tdata, Tcompute>(
output,
input_standardization,
input_std_deviation,
input,
weight,
bias,
eps,
normalized_size,
output_strides,
input_standardization_strides,
input_std_deviation_strides,
input_strides,
weight_stride,
bias_stride,
bias_exist
);
}
// ----------------------------------- start: call launchKernel -----------------------------------
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t calculate_layer_norm(
const LayerNormInfo &info,
Tdata *output,
Tdata *input_standardization,
Tdata *input_std_deviation,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
hcStream_t stream,
void *workspace) {
size_t ndim = info.ndim;
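// Workspace layout (matches WorkSpaceSize = sizeof(ptrdiff_t) * ndim * 4 set in create):
// four consecutive ndim-long ptrdiff_t arrays holding the input, output,
// standardization and std-deviation strides, filled by the async copies below.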
ptrdiff_t *input_strides_cuda = reinterpret_cast<ptrdiff_t *>(workspace);
ptrdiff_t *output_strides_cuda = input_strides_cuda + ndim;
ptrdiff_t *input_standardization_strides_cuda = output_strides_cuda + ndim;
ptrdiff_t *input_std_deviation_strides_cuda = input_standardization_strides_cuda + ndim;
CHECK_METAX(hcMemcpyAsync(input_strides_cuda, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(output_strides_cuda, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(input_standardization_strides_cuda, info.input_standardization_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(input_std_deviation_strides_cuda, info.input_std_deviation_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), hcMemcpyHostToDevice, stream));
launchKernel<1, Tdata, float><<<dim3(info.input_shape[0], info.input_shape[1]), 1, 0, stream>>>(
output,
input_standardization,
input_std_deviation,
input,
weight,
bias,
info.eps,
info.normalized_size,
output_strides_cuda,
input_standardization_strides_cuda,
input_std_deviation_strides_cuda,
input_strides_cuda,
info.weight_strides[0],
info.bias_exist ? info.bias_strides[0] : 0,
info.bias_exist
);
return INFINI_STATUS_SUCCESS;
}
// ------------------------------------ end: call launchKernel ------------------------------------
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_standardization_desc,
infiniopTensorDescriptor_t input_std_deviation_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t bias_desc,
float eps
) {
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
// --------------------- start: check data type and calculate workspace size ----------------------
auto dtype = output_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
auto result = LayerNormInfo::createLayerNormInfo(
output_desc,
input_standardization_desc,
input_std_deviation_desc,
input_desc,
weight_desc,
bias_desc,
eps
);
CHECK_RESULT(result);
const LayerNormInfo &info = result.take();
size_t WorkSpaceSize = sizeof(ptrdiff_t) * input_desc->ndim() * 4;
// ---------------------- end: check data type and calculate workspace size -----------------------
*desc_ptr = new Descriptor(
dtype, std::move(info), WorkSpaceSize,
new Opaque{handle->internal()},
handle->device, handle->device_id
);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
void *input_standardization,
void *input_std_deviation,
const void *input,
const void *weight,
const void *bias,
void *stream_) const {
if (workspace_size < _workspace_size)
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
hcStream_t stream = (hcStream_t)stream_;
#define CALCULATE_LAYER_NORM(BLOCK_SIZE, TDATA) \
calculate_layer_norm<BLOCK_SIZE, TDATA>(_info, (TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias, stream, workspace)
#define CALCULATE_LAYER_NORM_WITH_METAX_BLOCK(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_LAYER_NORM(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_LAYER_NORM(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_LAYER_NORM(BLOCK_SIZE, cuda_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024)
CALCULATE_LAYER_NORM_WITH_METAX_BLOCK(METAX_BLOCK_SIZE_1024)
else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512)
CALCULATE_LAYER_NORM_WITH_METAX_BLOCK(METAX_BLOCK_SIZE_512)
else
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_SUCCESS;
#undef CALCULATE_LAYER_NORM_WITH_METAX_BLOCK
#undef CALCULATE_LAYER_NORM
}
} // namespace op::layer_norm::metax
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "../info.h"
#include "layer_norm_nvidia.cuh"
#include <cub/block/block_reduce.cuh>
namespace op::layer_norm::nvidia {
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL launchKernel(
Tdata *output,
Tdata *input_standardization,
Tdata *input_std_deviation,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
float eps,
size_t normalized_size,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_standardization_strides,
const ptrdiff_t *input_std_deviation_strides,
const ptrdiff_t *input_strides,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
bool bias_exist) {
layerNormKernel<BLOCK_SIZE, Tdata, Tcompute>(
output,
input_standardization,
input_std_deviation,
input,
weight,
bias,
eps,
normalized_size,
output_strides,
input_standardization_strides,
input_std_deviation_strides,
input_strides,
weight_stride,
bias_stride,
bias_exist);
}
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockLayernorm(
Tdata *output,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
float eps,
int dimsize,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
int ndim,
bool bias_exist) {
blockLayernormKernel<Tdata, BLOCK_SIZE>(output,
input,
weight,
bias,
eps,
dimsize,
output_strides,
input_strides,
shape,
weight_stride,
bias_stride,
ndim,
bias_exist);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpLayernorm(
Tdata *output,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
float eps,
int othersize,
int dimsize,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape,
ptrdiff_t weight_stride,
ptrdiff_t bias_stride,
int ndim,
bool bias_exist) {
warpLayernormKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(output,
input,
weight,
bias,
eps,
othersize,
dimsize,
output_strides,
input_strides,
shape,
weight_stride,
bias_stride,
ndim,
bias_exist);
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t calculate_layer_norm(
const LayerNormInfo &info,
Tdata *output,
Tdata *input_standardization,
Tdata *input_std_deviation,
const Tdata *input,
const Tdata *weight,
const Tdata *bias,
cudaStream_t stream,
void *workspace) {
size_t ndim = info.ndim;
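// Workspace layout (matches WorkSpaceSize in create): 4 * ndim ptrdiff_t stride
// arrays (input, output, standardization, std-deviation) followed by an
// ndim-long size_t copy of the input shape, all copied to the device below.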
char *workspace_ptr = reinterpret_cast<char *>(workspace);
ptrdiff_t *input_strides_cuda = reinterpret_cast<ptrdiff_t *>(workspace_ptr);
ptrdiff_t *output_strides_cuda = input_strides_cuda + ndim;
ptrdiff_t *input_standardization_strides_cuda = output_strides_cuda + ndim;
ptrdiff_t *input_std_deviation_strides_cuda = input_standardization_strides_cuda + ndim;
size_t ptrdiff_array_size = 4 * ndim * sizeof(ptrdiff_t);
size_t *shape_cuda = reinterpret_cast<size_t *>(workspace_ptr + ptrdiff_array_size);
CHECK_CUDA(cudaMemcpyAsync(input_strides_cuda, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(output_strides_cuda, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(input_standardization_strides_cuda, info.input_standardization_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(input_std_deviation_strides_cuda, info.input_std_deviation_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(shape_cuda, info.input_shape.data(), sizeof(size_t) * ndim, cudaMemcpyHostToDevice, stream));
int dimsize = (int)info.normalized_size;
int num_blocks = (int)info.othersize;
if (dimsize > 1024) {
blockLayernorm<Tdata, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(output,
input,
weight,
bias,
info.eps,
dimsize,
output_strides_cuda,
input_strides_cuda,
shape_cuda,
info.weight_strides[0],
info.bias_exist ? info.bias_strides[0] : 0,
(int)info.ndim,
info.bias_exist);
} else {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLayernorm<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(output,
input,
weight,
bias,
info.eps,
num_blocks,
dimsize,
output_strides_cuda,
input_strides_cuda,
shape_cuda,
info.weight_strides[0],
info.bias_exist ? info.bias_strides[0] : 0,
(int)info.ndim,
info.bias_exist);
}
return INFINI_STATUS_SUCCESS;
}
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_standardization_desc,
infiniopTensorDescriptor_t input_std_deviation_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t bias_desc,
float eps) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto dtype = output_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
size_t WorkSpaceSize = output_desc->ndim() * (sizeof(ptrdiff_t) * 4 + sizeof(size_t));
auto result = LayerNormInfo::createLayerNormInfo(
output_desc,
input_standardization_desc,
input_std_deviation_desc,
input_desc,
weight_desc,
bias_desc,
eps);
CHECK_RESULT(result);
const LayerNormInfo &info = result.take();
*desc_ptr = new Descriptor(
dtype, std::move(info), WorkSpaceSize,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
void *input_standardization,
void *input_std_deviation,
const void *input,
const void *weight,
const void *bias,
void *stream_) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
cudaStream_t stream = (cudaStream_t)stream_;
#define CALCULATE_LAYER_NORM(BLOCK_SIZE, TDATA) \
calculate_layer_norm<BLOCK_SIZE, TDATA>(_info, (TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias, stream, workspace)
#define CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_LAYER_NORM(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_LAYER_NORM(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_LAYER_NORM(BLOCK_SIZE, __nv_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::layer_norm::nvidia
#ifndef __LAYER_NORM_NVIDIA_API_H__
#define __LAYER_NORM_NVIDIA_API_H__
#include "../layer_norm.h"
DESCRIPTOR(nvidia)
#endif // __LAYER_NORM_NVIDIA_API_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/layer_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/layer_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/layer_norm_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/layer_norm_metax.h"
#endif
__C infiniStatus_t infiniopCreateLayerNormDescriptor(
infiniopHandle_t handle,
infiniopLayerNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_standardization_desc,
infiniopTensorDescriptor_t input_std_deviation_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t bias_desc,
float eps) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::layer_norm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::layer_norm::NAMESPACE::Descriptor **>(desc_ptr), \
output_desc, \
input_standardization_desc, \
input_std_deviation_desc, \
input_desc, \
weight_desc, \
bias_desc, \
eps)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::layer_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopLayerNorm(
infiniopLayerNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *output,
void *input_standardization,
void *input_std_deviation,
const void *input,
const void *weight,
const void *bias,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::layer_norm::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, \
workspace_size, \
output, \
input_standardization, \
input_std_deviation, \
input, \
weight, \
bias, \
stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t
infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::layer_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
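The exported functions above follow the usual InfiniOp create / query-workspace / compute / destroy sequence. A minimal host-side sketch, assuming the handle, tensor descriptors, device buffers and stream are set up through the existing InfiniOp APIs (not part of this diff):

// Hedged usage sketch; error checking and buffer allocation omitted.
infiniopLayerNormDescriptor_t ln_desc = nullptr;
infiniopCreateLayerNormDescriptor(handle, &ln_desc,
                                  output_desc, input_standardization_desc,
                                  input_std_deviation_desc, input_desc,
                                  weight_desc, /*bias_desc=*/nullptr,
                                  /*eps=*/1e-5f);

size_t ws_size = 0;
infiniopGetLayerNormWorkspaceSize(ln_desc, &ws_size);
// ... allocate `workspace` with at least ws_size bytes on the target device ...

infiniopLayerNorm(ln_desc, workspace, ws_size,
                  output, input_standardization, input_std_deviation,
                  input, weight, /*bias=*/nullptr, stream);

infiniopDestroyLayerNormDescriptor(ln_desc);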
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/logsoftmax_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/logsoftmax_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -39,6 +39,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
#ifdef ENABLE_ILUVATAR_API
// CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
#endif
@@ -66,6 +69,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
#ifdef ENABLE_ILUVATAR_API
// GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
#endif
@@ -98,6 +104,9 @@ __C infiniStatus_t infiniopLogSoftmax(
#ifdef ENABLE_ILUVATAR_API
// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
@@ -125,6 +134,9 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
#ifdef ENABLE_ILUVATAR_API
// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifndef __LP_NORM_KERNEL_CUH__
#define __LP_NORM_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
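// This header provides four Lp-norm kernel variants. blockLPNormKernel and
// blockLPNormStridesKernel reduce a whole row with cub::BlockReduce (used when
// dimsize > 1024); warpLPNormKernel and warpLPNormStridesKernel pack 32 rows
// per block and reduce each row with one warp. The *Strides* variants handle
// arbitrary strides but only normalize along the last axis. All of them first
// reduce the row maximum and rescale by 1/max before calling powf(), which
// keeps the p-th powers from overflowing for large inputs.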
template <typename T, unsigned int BLOCK_SIZE>
__device__ void blockLPNormKernel(
T const *input, T *output, float p, size_t dimsize,
ptrdiff_t stride, float eps) {
int tid = blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) * dimsize; // now, tid = i(JKS) + k(S) + s;
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
float local_max = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE) {
local_max = max(local_max, fabsf((float)input[tid + ind * stride]));
}
__shared__ float global_max;
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
if (threadIdx.x == 0) { // only thread 0 holds the reduced value and writes it to shared memory
global_max = max_block;
}
__syncthreads();
float global_max_inv = __fdividef(1.0F, max(global_max, eps));
float p_partial = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE) {
p_partial += powf(fabsf((float)input[tid + ind * stride]) * global_max_inv, p); // take |x| so negative inputs behave as in the warp kernels
}
__shared__ float p_total;
float p_block = BlockReduce(temp_storage).Reduce(p_partial, cub::Sum());
if (threadIdx.x == 0) { // only thread 0 holds the reduced value and writes it to shared memory
p_total = powf(p_block, 1.0f / p);
}
__syncthreads();
float inv = __fdividef(1.0F, p_total + eps) * global_max_inv;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE) {
output[tid + ind * stride] = static_cast<T>(
static_cast<float>(
input[tid + ind * stride])
* inv);
}
}
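In formula form, with the max-rescaling used for numerical stability, the kernels in this header transform each slice x along the norm axis as

\[
m = \max\!\big(\max_i |x_i|,\ \varepsilon\big), \qquad
\|x\|_p = m \left( \sum_i \Big(\frac{|x_i|}{m}\Big)^{p} \right)^{1/p}, \qquad
y_i = \frac{x_i}{\|x\|_p + \varepsilon\, m},
\]

i.e. the usual \(x / (\|x\|_p + \varepsilon)\), except that the \(\varepsilon\) term ends up scaled by \(m\) as a side effect of the rescaling.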
template <typename T, unsigned int BLOCK_SIZE>
__device__ void blockLPNormStridesKernel(
T const *input, T *output, const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape, int ndim, float p, size_t dimsize,
float eps) {
// Only handles normalization along the last axis (axis == -1).
int ind_i = 0; // input id
int ind_o = 0; // output id
int tid = blockIdx.x;
for (int j = ndim - 2; j >= 0; j--) {
ind_i += (tid % (int)shape[j]) * (int)input_strides[j];
ind_o += (tid % (int)shape[j]) * (int)output_strides[j];
tid = tid / (int)shape[j];
}
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
float local_max = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE) {
local_max = max(local_max, fabsf((float)input[ind_i + ind]));
}
__shared__ float global_max;
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
if (threadIdx.x == 0) { // only thread 0 holds the reduced value and writes it to shared memory
global_max = max_block;
}
__syncthreads();
float global_max_inv = __fdividef(1.0F, max(global_max, eps));
float p_partial = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE) {
p_partial += powf(fabsf((float)input[ind_i + ind]) * global_max_inv, p); // take |x| so negative inputs behave as in the warp kernels
}
__shared__ float p_total;
float p_block = BlockReduce(temp_storage).Reduce(p_partial, cub::Sum());
if (threadIdx.x == 0) { // only thread 0 holds the reduced value and writes it to shared memory
p_total = powf(p_block, 1.0f / p);
}
__syncthreads();
float inv = __fdividef(1.0F, p_total + eps) * global_max_inv;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE) {
output[ind_o + ind] = static_cast<T>(
static_cast<float>(
input[ind_i + ind])
* inv);
}
}
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return a + b;
}
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max(a, b);
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <typename T, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpLPNormKernel(T const *input, T *output,
float p, size_t othersize, size_t dimsize,
ptrdiff_t stride, float eps) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < othersize) {
__shared__ float p_total[BLOCK_SIZE_y];
__shared__ float p_max[BLOCK_SIZE_y];
float local_max = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE_x) {
local_max = max(local_max, fabsf((float)input[tid + ind * stride]));
}
local_max = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(local_max);
if (threadIdx.x == 0) {
p_max[threadIdx.y] = local_max;
}
__syncthreads();
float global_max = max(p_max[threadIdx.y], eps);
float global_max_inv = __fdividef(1.0F, global_max);
float p_data = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE_x) {
float v = fabsf((float)input[tid + ind * stride]) * global_max_inv;
p_data += powf(v, p);
}
p_data = WarpAllReduce<SumOp, float, BLOCK_SIZE_x>(p_data);
if (threadIdx.x == 0) {
p_total[threadIdx.y] = powf(p_data, 1.0f / p);
}
__syncthreads();
//--------------------------------------------
float inv = __fdividef(1.0F, p_total[threadIdx.y] + eps) * global_max_inv;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE_x) {
output[tid + ind * stride] = static_cast<T>(
(float)input[tid + ind * stride] * inv);
}
}
}
template <typename T, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpLPNormStridesKernel(T const *input, T *output, const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape, int ndim,
float p, size_t othersize, size_t dimsize,
float eps) {
int ind_i = 0; // input id
int ind_o = 0; // output id
int tid = blockIdx.x * blockDim.y + threadIdx.y;
if (tid < othersize) {
for (int j = ndim - 2; j >= 0; j--) {
ind_i += (tid % (int)shape[j]) * (int)input_strides[j];
ind_o += (tid % (int)shape[j]) * (int)output_strides[j];
tid = tid / (int)shape[j];
}
__shared__ float p_total[BLOCK_SIZE_y];
__shared__ float p_max[BLOCK_SIZE_y];
float local_max = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE_x) {
local_max = max(local_max, fabsf((float)input[ind_i + ind]));
}
local_max = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(local_max);
if (threadIdx.x == 0) {
p_max[threadIdx.y] = local_max;
}
__syncthreads();
float global_max = max(p_max[threadIdx.y], eps);
float global_max_inv = __fdividef(1.0F, global_max);
float p_data = 0.0f;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE_x) {
float v = fabsf((float)input[ind_i + ind]) * global_max_inv;
p_data += powf(v, p);
}
p_data = WarpAllReduce<SumOp, float, BLOCK_SIZE_x>(p_data);
if (threadIdx.x == 0) {
p_total[threadIdx.y] = powf(p_data, 1.0f / p);
}
__syncthreads();
//--------------------------------------------
float inv = __fdividef(1.0F, p_total[threadIdx.y] + eps) * global_max_inv;
for (int ind = threadIdx.x; ind < dimsize; ind += BLOCK_SIZE_x) {
output[ind_o + ind] = static_cast<T>(
(float)input[ind_i + ind] * inv);
}
}
}
#endif // __LP_NORM_KERNEL_CUH__
#ifndef __LP_NORM_INFO_H__
#define __LP_NORM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
namespace op::lp_norm {
class LPNormInfo {
private:
LPNormInfo() = default;
public:
// ---------------------------- start: define member variables of Info ----------------------------
infiniDtype_t dtype;
size_t dimsize;
size_t othersize;
ptrdiff_t stride;
int axis;
int p;
float eps;
bool continuous;
size_t ndim;
std::vector<size_t> input_shape;
std::vector<ptrdiff_t> input_strides;
std::vector<ptrdiff_t> output_strides;
// ----------------------------- end: define member variables of Info -----------------------------
static utils::Result<LPNormInfo> createLPNormInfo(
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
int axis,
int p,
float eps) {
auto dtype = output_desc->dtype();
if (dtype != input_desc->dtype()) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
auto input_shape = input_desc->shape();
auto input_strides = input_desc->strides();
auto output_strides = output_desc->strides();
size_t ndim = input_desc->ndim();
if (input_strides[ndim - 1] != 1 || output_strides[ndim - 1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (axis < 0) {
axis += (int)(ndim);
}
size_t othersize = 1;
for (int i = 0; i < (int)ndim; i++) {
if (i != axis) {
othersize *= input_shape[i];
}
}
ptrdiff_t stride = 1;
for (int i = (int)ndim - 1; i > axis; i--) {
stride *= (ptrdiff_t)input_shape[i];
}
size_t dimsize = input_shape[axis];
bool continuous = true;
ptrdiff_t continuous_stride = 1;
for (int i = (int)ndim - 1; i >= 0; i--) {
if (continuous_stride != input_strides[i] || continuous_stride != output_strides[i]) {
continuous = false;
break;
}
continuous_stride *= (ptrdiff_t)input_shape[i];
}
return utils::Result<LPNormInfo>(LPNormInfo{
dtype,
dimsize,
othersize,
stride,
axis,
p,
eps,
continuous,
ndim,
input_shape,
input_strides,
output_strides});
}
};
} // namespace op::lp_norm
#endif // __LP_NORM_INFO_H__
#ifndef LPNORM_H
#define LPNORM_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::lp_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
LPNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
LPNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
int axis, \
int p, \
float eps); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *x, \
void *stream) const; \
}; \
}
#endif // LPNORM_H
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "lp_norm_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockLPNorm(
Tdata *y, const Tdata *x,
float p,
size_t dimsize,
ptrdiff_t stride, float eps) {
blockLPNormKernel<Tdata, BLOCK_SIZE>(x, y, p, dimsize, stride, eps);
}
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockLPNormStrides(
Tdata *y, const Tdata *x,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape, int ndim, float p, size_t dimsize,
float eps) {
blockLPNormStridesKernel<Tdata, BLOCK_SIZE>(x, y, output_strides, input_strides, shape, ndim, p, dimsize, eps);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpLPNorm(
Tdata *y, const Tdata *x,
float p,
size_t othersize,
size_t dimsize,
ptrdiff_t stride, float eps) {
warpLPNormKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x, y, p, othersize, dimsize, stride, eps);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpLPNormStrides(
Tdata *y, const Tdata *x,
const ptrdiff_t *output_strides,
const ptrdiff_t *input_strides,
const size_t *shape, int ndim,
float p, size_t othersize, size_t dimsize,
float eps) {
warpLPNormStridesKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x, y, output_strides, input_strides, shape, ndim, p, othersize, dimsize, eps);
}
namespace op::lp_norm::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
int axis,
int p,
float eps) {
auto info = LPNormInfo::createLPNormInfo(y_desc, x_desc, axis, p, eps);
CHECK_RESULT(info);
size_t workspace_size = y_desc->ndim() * (sizeof(ptrdiff_t) * 2 + sizeof(size_t));
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const LPNormInfo &info, Tdata *y, const Tdata *x,
cudaStream_t stream, void *workspace) {
size_t dimsize = info.dimsize;
size_t othersize = info.othersize;
float p_f = static_cast<float>(info.p);
float eps = info.eps;
int num_blocks = static_cast<int>(info.othersize);
ptrdiff_t stride = info.stride;
int ndim = (int)info.ndim;
char *workspace_ptr = reinterpret_cast<char *>(workspace);
ptrdiff_t *input_strides_cuda = reinterpret_cast<ptrdiff_t *>(workspace_ptr);
ptrdiff_t *output_strides_cuda = input_strides_cuda + ndim;
size_t ptrdiff_array_size = 2 * ndim * sizeof(ptrdiff_t);
size_t *shape_cuda = reinterpret_cast<size_t *>(workspace_ptr + ptrdiff_array_size);
CHECK_CUDA(cudaMemcpyAsync(input_strides_cuda, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(output_strides_cuda, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(shape_cuda, info.input_shape.data(), sizeof(size_t) * ndim, cudaMemcpyHostToDevice, stream));
if (info.continuous) {
if (dimsize > 1024) {
blockLPNorm<Tdata, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(y, x,
p_f, dimsize, stride, eps);
} else {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLPNorm<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(y, x,
p_f, othersize, dimsize, stride, eps);
}
} else {
if (info.axis == ndim - 1) {
if (dimsize > 1024) {
blockLPNormStrides<Tdata, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(y, x, output_strides_cuda, input_strides_cuda, shape_cuda, ndim,
p_f, dimsize, eps);
} else {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLPNormStrides<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(y, x, output_strides_cuda, input_strides_cuda, shape_cuda, ndim,
p_f, othersize, dimsize, eps);
}
} else {
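// Non-contiguous tensors are currently only supported when the norm axis is
// the innermost dimension; any other axis falls through to BAD_PARAM below.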
return INFINI_STATUS_BAD_PARAM;
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
#define CALCULATE_LP_NORM(BLOCK_SIZE, TDATA) \
launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)y, (const TDATA *)x, stream, workspace)
#define CALCULATE_LP_NORM_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_LP_NORM(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_LP_NORM(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_LP_NORM(BLOCK_SIZE, __nv_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::lp_norm::nvidia
#ifndef __LP_NORM_NVIDIA_API_H__
#define __LP_NORM_NVIDIA_API_H__
#include "../lp_norm.h"
DESCRIPTOR(nvidia)
#endif // __LP_NORM_NVIDIA_API_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/lp_norm.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/lp_norm_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreateLPNormDescriptor(
infiniopHandle_t handle,
infiniopLPNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
int axis,
int p,
float eps) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::lp_norm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::lp_norm::NAMESPACE::Descriptor **>(desc_ptr), \
output_desc, \
input_desc, \
axis, \
p, \
eps)
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetLPNormWorkspaceSize(infiniopLPNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::lp_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopLPNorm(
infiniopLPNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *output,
const void *input,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::lp_norm::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, \
workspace_size, \
output, \
input, \
stream)
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t
infiniopDestroyLPNormDescriptor(infiniopLPNormDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::lp_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
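As with layer norm, the LP-norm entry points follow the same create / query-workspace / compute / destroy pattern. A minimal sketch, again assuming the handle, tensor descriptors and device buffers come from existing InfiniOp setup code:

// Hedged usage sketch: L2-normalize the last axis of `input` into `output`.
infiniopLPNormDescriptor_t lp_desc = nullptr;
infiniopCreateLPNormDescriptor(handle, &lp_desc, output_desc, input_desc,
                               /*axis=*/-1, /*p=*/2, /*eps=*/1e-12f);

size_t ws_size = 0;
infiniopGetLPNormWorkspaceSize(lp_desc, &ws_size);
// ... allocate `workspace` with at least ws_size bytes on the device ...

infiniopLPNorm(lp_desc, workspace, ws_size, output, input, stream);
infiniopDestroyLPNormDescriptor(lp_desc);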
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/mul_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/mul_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -42,6 +42,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -73,6 +76,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
@@ -113,6 +119,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -146,6 +155,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/ones_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/ones_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -40,6 +40,9 @@ __C infiniStatus_t infiniopCreateOnesDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -70,6 +73,9 @@ __C infiniStatus_t infiniopGetOnesWorkspaceSize(infiniopOnesDescriptor_t desc, s
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
@@ -108,6 +114,9 @@ __C infiniStatus_t infiniopOnes(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -140,6 +149,9 @@ infiniopDestroyOnesDescriptor(infiniopOnesDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/random_sample_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/random_sample_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
@@ -50,6 +50,9 @@ infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
@@ -98,6 +101,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
@@ -156,6 +162,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
@@ -201,6 +210,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif