Commit 5c0cb198 authored by YdrMaster

issue/191/fix: fix review change request


Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent dafe0ae5
#include "../../../../utils.h" #include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "infinicore.h" #include "infinicore.h"
#include <cstddef>
#include <cub/device/device_radix_sort.cuh> #include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_reduce.cuh> #include <cub/device/device_reduce.cuh>
#include <cub/device/device_scan.cuh> #include <cub/device/device_scan.cuh>
@@ -12,7 +11,7 @@ namespace op::random_sample::cuda {
 template <class T>
 static cudaError argMax_(
     cub::KeyValuePair<int, T> *kv_pair,
-    T const *logits,
+    const T *logits,
     int n,
     void *workspace_ptr,
     size_t &workspace_len,
@@ -26,8 +25,8 @@ static cudaError argMax_(
 template <class Tval, class Tidx>
 static cudaError radixSort(
     void *workspace_ptr, size_t &workspace_len,
-    Tval const *key_in, Tval *key_out,
-    Tidx const *val_in, Tidx *val_out,
+    const Tval *key_in, Tval *key_out,
+    const Tidx *val_in, Tidx *val_out,
     int n,
     cudaStream_t stream) {
     return cub::DeviceRadixSort::SortPairsDescending(
@@ -53,8 +52,6 @@ static cudaError inclusiveSum(
 // ↑↑↑ Wrappers around the cub API that cut down template parameters and simplify call sites
 // ↓↓↓ Workspace size calculation
-#define CHECK_CUB(API) CHECK_INTERNAL(API, cudaSuccess)
-
 // Align addresses up to a multiple of 256
 static constexpr size_t align256(size_t size) {
     return (size + 255) & (~255);
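
The expression (size + 255) & ~255 rounds a byte count up to the next multiple of 256 by adding 255 and then clearing the low eight bits, so the sub-buffers carved out of one workspace allocation all start on 256-byte boundaries. A minimal standalone sketch of the same helper with a few worked values:

// Standalone sketch of the rounding trick used by align256 above.
#include <cassert>
#include <cstddef>

static constexpr size_t align256(size_t size) {
    // Adding 255 and clearing the low 8 bits rounds up to a multiple of 256.
    return (size + 255) & ~size_t(255);
}

int main() {
    assert(align256(0) == 0);     // zero stays zero
    assert(align256(1) == 256);   // anything in (0, 256] rounds to 256
    assert(align256(256) == 256); // exact multiples are unchanged
    assert(align256(257) == 512); // otherwise, the next 256-byte boundary
}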
@@ -62,10 +59,10 @@ static constexpr size_t align256(size_t size) {
 template <class Tidx, class Tval>
 utils::Result<size_t> calculateWorkspace(size_t n_) {
-    auto const n = static_cast<int>(n_);
+    const auto n = static_cast<int>(n_);
     size_t argmax;
-    CHECK_CUB(argMax_<Tval>(
+    CHECK_CUDA(argMax_<Tval>(
         nullptr, nullptr, n,
         nullptr, argmax,
         nullptr));
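
The nullptr arguments in this hunk follow cub's two-phase convention: calling a device-wide algorithm with a null temp-storage pointer makes it write the required byte count into the size reference (argmax here) and return without launching any work. calculateWorkspace uses only that query phase so one buffer can be allocated up front. A self-contained sketch of the full two-phase pattern with cub::DeviceRadixSort; the function name and the use of cudaMallocAsync (CUDA 11.2+) are illustrative choices, not part of this file:

// Two-phase cub call: first query the workspace size, then run for real.
#include <cub/device/device_radix_sort.cuh>

cudaError_t sortDescending(const float *d_keys_in, float *d_keys_out,
                           const int *d_vals_in, int *d_vals_out,
                           int n, cudaStream_t stream) {
    void *d_temp = nullptr;
    size_t temp_bytes = 0;
    // Pass 1: d_temp == nullptr, so cub only reports the bytes it needs.
    cudaError_t err = cub::DeviceRadixSort::SortPairsDescending(
        d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
        0, sizeof(float) * 8, stream);
    if (err != cudaSuccess) return err;
    if ((err = cudaMallocAsync(&d_temp, temp_bytes, stream)) != cudaSuccess) return err;
    // Pass 2: the same call with real storage performs the sort.
    err = cub::DeviceRadixSort::SortPairsDescending(
        d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
        0, sizeof(float) * 8, stream);
    cudaFreeAsync(d_temp, stream);
    return err;
}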
@@ -80,7 +77,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
     size_random += align256(sizeof(Tidx) * n);
     // cub device api
     size_t size_radix_sort;
-    CHECK_CUB((radixSort<Tval, Tidx>(
+    CHECK_CUDA((radixSort<Tval, Tidx>(
         nullptr, size_radix_sort,
         nullptr, nullptr,
         nullptr, nullptr,
@@ -88,7 +85,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
         nullptr)));
     size_t size_inclusive_sum;
-    CHECK_CUB(inclusiveSum<Tval>(
+    CHECK_CUDA(inclusiveSum<Tval>(
         nullptr, size_inclusive_sum,
         nullptr, n,
         nullptr));
@@ -155,8 +152,8 @@ static __global__ void setSoftmaxMaxKernel(
 template <class Tval, class Tidx>
 static __global__ void randomSampleKernel(
     Tidx *__restrict__ result,
-    Tval const *__restrict__ sorted,
-    Tidx const *__restrict__ indices_out,
+    const Tval *__restrict__ sorted,
+    const Tidx *__restrict__ indices_out,
     size_t n,
     float random, float topp, size_t topk) {
     topk = cub::Min()(topk, n);
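
randomSampleKernel consumes the descending-sorted weights after they have been run through the inclusive sum, so the sampling step reduces to a threshold search over a cumulative array: top-k and top-p each cap the usable probability mass, and the uniform draw is scaled into the tighter cap. A host-side reference of that idea (illustrative only; the exact thresholding arithmetic inside the kernel is not visible in this hunk):

// Host-side reference for the sampling step over an inclusive prefix sum of
// descending-sorted weights. Names mirror the kernel parameters above.
#include <algorithm>
#include <cstddef>

template <class Tval, class Tidx>
Tidx sampleReference(const Tval *cumsum,      // inclusive sum of sorted weights
                     const Tidx *indices_out, // token ids after the descending sort
                     size_t n, float random, float topp, size_t topk) {
    topk = std::min(topk, n); // same clipping as the kernel's first statement
    // top-k and top-p both limit the usable probability mass; take the tighter
    // bound and scale the uniform draw into it.
    Tval mass = std::min<Tval>(cumsum[topk - 1], topp * cumsum[n - 1]);
    Tval threshold = random * mass;
    // The sample is the first position whose cumulative weight reaches the threshold.
    size_t i = 0;
    while (i + 1 < n && cumsum[i] < threshold) {
        ++i;
    }
    return indices_out[i]; // map the sorted position back to its token id
}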
@@ -177,7 +174,7 @@ struct Algo {
 template <class Tidx, class Tval_>
 infiniStatus_t argmax(
     void *workspace, size_t workspace_size,
-    void *result, void const *probs, size_t n,
+    void *result, const void *probs, size_t n,
     void *stream_) const {
     using Tval = typename CudaTval<Tval_>::Type;
@@ -202,7 +199,7 @@ struct Algo {
 template <class Tidx, class Tval_>
 infiniStatus_t random(
     void *workspace_, size_t workspace_size,
-    void *result_, void const *probs, size_t n,
+    void *result_, const void *probs, size_t n,
     float random_val, float topp, int topk, float temperature,
     void *stream_) const {
@@ -231,7 +228,7 @@ struct Algo {
     auto grid = (n + block - 1) / block;
     // sort
     fillIndices<<<grid, block, 0, stream>>>(indices, n);
-    CHECK_CUB(radixSort(
+    CHECK_CUDA(radixSort(
         workspace_, workspace_size,
         logits, sorted,
         indices, indices_out,
@@ -241,7 +238,7 @@ struct Algo {
     partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
     setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
     // sum
-    CHECK_CUB(inclusiveSum(
+    CHECK_CUDA(inclusiveSum(
         workspace_, workspace,
         sorted, n,
         stream));
...
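
For orientation, the launch sequence visible across the hunks of Algo::random reads as a fixed pipeline on one stream; the role of each step below is inferred from its name and arguments, not spelled out in the diff:

// 1. fillIndices          - write 0..n-1 so every logit carries its token id
// 2. radixSort            - sort (logit, id) pairs into descending order
// 3. partialSoftmaxKernel - temperature-scaled exponentiation of the sorted values
// 4. setSoftmaxMaxKernel  - fix up the leading (max) element the partial pass skips
// 5. inclusiveSum         - prefix-sum the weights for cumulative sampling
// 6. randomSampleKernel   - pick the first index whose prefix sum crosses the draw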
@@ -17,17 +17,12 @@ struct RandomSampleInfo {
     auto dt_i = result_desc->dtype();
     auto dt_p = probs_desc->dtype();
-    CHECK_DTYPE(dt_i,
-                INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64,
-                INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
+    CHECK_DTYPE_ANY_INT(dt_i);
     CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
-    CHECK_API_OR(result_desc->ndim(), 0,
-                 return INFINI_STATUS_BAD_TENSOR_SHAPE);
-    CHECK_API_OR(probs_desc->ndim(), 1,
-                 return INFINI_STATUS_BAD_TENSOR_SHAPE);
-    CHECK_API_OR(probs_desc->stride(0), 1,
-                 return INFINI_STATUS_BAD_TENSOR_STRIDES);
+    CHECK_OR_RETURN(result_desc->ndim() == 0, INFINI_STATUS_BAD_TENSOR_SHAPE);
+    CHECK_OR_RETURN(probs_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
+    CHECK_OR_RETURN(probs_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
     return utils::Result<RandomSampleInfo>({dt_i, dt_p, probs_desc->dim(0)});
 }
...
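
The review change swaps the value-comparing CHECK_API_OR calls for condition-based CHECK_OR_RETURN and folds the integer dtype list into CHECK_DTYPE_ANY_INT. The macro definitions live in the project's shared headers rather than in this diff; a plausible shape for CHECK_OR_RETURN, assuming the usual early-return pattern:

// Hypothetical sketch only; the real definition is in the shared check headers.
#define CHECK_OR_RETURN(CONDITION, ERROR_STATUS) \
    do {                                         \
        if (!(CONDITION)) {                      \
            return ERROR_STATUS;                 \
        }                                        \
    } while (0)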