Commit 5c0cb198 authored by YdrMaster's avatar YdrMaster
Browse files

issue/191/fix: fix review change request


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent dafe0ae5
#include "../../../../utils.h"
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "infinicore.h"
#include <cstddef>
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_scan.cuh>
......@@ -12,7 +11,7 @@ namespace op::random_sample::cuda {
template <class T>
static cudaError argMax_(
cub::KeyValuePair<int, T> *kv_pair,
T const *logits,
const T *logits,
int n,
void *workspace_ptr,
size_t &workspace_len,
......@@ -26,8 +25,8 @@ static cudaError argMax_(
template <class Tval, class Tidx>
static cudaError radixSort(
void *workspace_ptr, size_t &workspace_len,
Tval const *key_in, Tval *key_out,
Tidx const *val_in, Tidx *val_out,
const Tval *key_in, Tval *key_out,
const Tidx *val_in, Tidx *val_out,
int n,
cudaStream_t stream) {
return cub::DeviceRadixSort::SortPairsDescending(
......@@ -53,8 +52,6 @@ static cudaError inclusiveSum(
// ↑↑↑ 重新封装 cub api,减少模板参数,方便调用
// ↓↓↓ 计算 workspace
#define CHECK_CUB(API) CHECK_INTERNAL(API, cudaSuccess)
// 地址对齐到 256
static constexpr size_t align256(size_t size) {
return (size + 255) & (~255);
......@@ -62,10 +59,10 @@ static constexpr size_t align256(size_t size) {
template <class Tidx, class Tval>
utils::Result<size_t> calculateWorkspace(size_t n_) {
auto const n = static_cast<int>(n_);
const auto n = static_cast<int>(n_);
size_t argmax;
CHECK_CUB(argMax_<Tval>(
CHECK_CUDA(argMax_<Tval>(
nullptr, nullptr, n,
nullptr, argmax,
nullptr));
......@@ -80,7 +77,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
size_random += align256(sizeof(Tidx) * n);
// cub device api
size_t size_radix_sort;
CHECK_CUB((radixSort<Tval, Tidx>(
CHECK_CUDA((radixSort<Tval, Tidx>(
nullptr, size_radix_sort,
nullptr, nullptr,
nullptr, nullptr,
......@@ -88,7 +85,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
nullptr)));
size_t size_inclusive_sum;
CHECK_CUB(inclusiveSum<Tval>(
CHECK_CUDA(inclusiveSum<Tval>(
nullptr, size_inclusive_sum,
nullptr, n,
nullptr));
......@@ -155,8 +152,8 @@ static __global__ void setSoftmaxMaxKernel(
template <class Tval, class Tidx>
static __global__ void randomSampleKernel(
Tidx *__restrict__ result,
Tval const *__restrict__ sorted,
Tidx const *__restrict__ indices_out,
const Tval *__restrict__ sorted,
const Tidx *__restrict__ indices_out,
size_t n,
float random, float topp, size_t topk) {
topk = cub::Min()(topk, n);
......@@ -177,7 +174,7 @@ struct Algo {
template <class Tidx, class Tval_>
infiniStatus_t argmax(
void *workspace, size_t workspace_size,
void *result, void const *probs, size_t n,
void *result, const void *probs, size_t n,
void *stream_) const {
using Tval = typename CudaTval<Tval_>::Type;
......@@ -202,7 +199,7 @@ struct Algo {
template <class Tidx, class Tval_>
infiniStatus_t random(
void *workspace_, size_t workspace_size,
void *result_, void const *probs, size_t n,
void *result_, const void *probs, size_t n,
float random_val, float topp, int topk, float temperature,
void *stream_) const {
......@@ -231,7 +228,7 @@ struct Algo {
auto grid = (n + block - 1) / block;
// sort
fillIndices<<<grid, block, 0, stream>>>(indices, n);
CHECK_CUB(radixSort(
CHECK_CUDA(radixSort(
workspace_, workspace_size,
logits, sorted,
indices, indices_out,
......@@ -241,7 +238,7 @@ struct Algo {
partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
// sum
CHECK_CUB(inclusiveSum(
CHECK_CUDA(inclusiveSum(
workspace_, workspace,
sorted, n,
stream));
......
......@@ -17,17 +17,12 @@ struct RandomSampleInfo {
auto dt_i = result_desc->dtype();
auto dt_p = probs_desc->dtype();
CHECK_DTYPE(dt_i,
INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64,
INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_DTYPE_ANY_INT(dt_i);
CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_API_OR(result_desc->ndim(), 0,
return INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_API_OR(probs_desc->ndim(), 1,
return INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_API_OR(probs_desc->stride(0), 1,
return INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(result_desc->ndim() == 0, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(probs_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(probs_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
return utils::Result<RandomSampleInfo>({dt_i, dt_p, probs_desc->dim(0)});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment