Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5c0cb198
Commit
5c0cb198
authored
May 12, 2025
by
YdrMaster
Browse files
issue/191/fix: fix review change request
Signed-off-by:
YdrMaster
<
ydrml@hotmail.com
>
parent
dafe0ae5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
26 deletions
+18
-26
src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh
src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh
+14
-17
src/infiniop/ops/random_sample/info.h
src/infiniop/ops/random_sample/info.h
+4
-9
No files found.
src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh
View file @
5c0cb198
#
include
"../../../
../utils.
h"
#
include
"../../../
devices/cuda/cuda_kernel_common.cu
h"
#include "infinicore.h"
#include <cstddef>
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_scan.cuh>
...
...
@@ -12,7 +11,7 @@ namespace op::random_sample::cuda {
template
<
class
T
>
static
cudaError
argMax_
(
cub
::
KeyValuePair
<
int
,
T
>
*
kv_pair
,
T
const
*
logits
,
const
T
*
logits
,
int
n
,
void
*
workspace_ptr
,
size_t
&
workspace_len
,
...
...
@@ -26,8 +25,8 @@ static cudaError argMax_(
template
<
class
Tval
,
class
Tidx
>
static
cudaError
radixSort
(
void
*
workspace_ptr
,
size_t
&
workspace_len
,
Tval
const
*
key_in
,
Tval
*
key_out
,
Tidx
const
*
val_in
,
Tidx
*
val_out
,
const
Tval
*
key_in
,
Tval
*
key_out
,
const
Tidx
*
val_in
,
Tidx
*
val_out
,
int
n
,
cudaStream_t
stream
)
{
return
cub
::
DeviceRadixSort
::
SortPairsDescending
(
...
...
@@ -53,8 +52,6 @@ static cudaError inclusiveSum(
// ↑↑↑ 重新封装 cub api,减少模板参数,方便调用
// ↓↓↓ 计算 workspace
#define CHECK_CUB(API) CHECK_INTERNAL(API, cudaSuccess)
// 地址对齐到 256
static
constexpr
size_t
align256
(
size_t
size
)
{
return
(
size
+
255
)
&
(
~
255
);
...
...
@@ -62,10 +59,10 @@ static constexpr size_t align256(size_t size) {
template
<
class
Tidx
,
class
Tval
>
utils
::
Result
<
size_t
>
calculateWorkspace
(
size_t
n_
)
{
auto
const
n
=
static_cast
<
int
>
(
n_
);
const
auto
n
=
static_cast
<
int
>
(
n_
);
size_t
argmax
;
CHECK_CU
B
(
argMax_
<
Tval
>
(
CHECK_CU
DA
(
argMax_
<
Tval
>
(
nullptr
,
nullptr
,
n
,
nullptr
,
argmax
,
nullptr
));
...
...
@@ -80,7 +77,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
size_random
+=
align256
(
sizeof
(
Tidx
)
*
n
);
// cub device api
size_t
size_radix_sort
;
CHECK_CU
B
((
radixSort
<
Tval
,
Tidx
>
(
CHECK_CU
DA
((
radixSort
<
Tval
,
Tidx
>
(
nullptr
,
size_radix_sort
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
...
...
@@ -88,7 +85,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
nullptr
)));
size_t
size_inclusive_sum
;
CHECK_CU
B
(
inclusiveSum
<
Tval
>
(
CHECK_CU
DA
(
inclusiveSum
<
Tval
>
(
nullptr
,
size_inclusive_sum
,
nullptr
,
n
,
nullptr
));
...
...
@@ -155,8 +152,8 @@ static __global__ void setSoftmaxMaxKernel(
template
<
class
Tval
,
class
Tidx
>
static
__global__
void
randomSampleKernel
(
Tidx
*
__restrict__
result
,
Tval
const
*
__restrict__
sorted
,
Tidx
const
*
__restrict__
indices_out
,
const
Tval
*
__restrict__
sorted
,
const
Tidx
*
__restrict__
indices_out
,
size_t
n
,
float
random
,
float
topp
,
size_t
topk
)
{
topk
=
cub
::
Min
()(
topk
,
n
);
...
...
@@ -177,7 +174,7 @@ struct Algo {
template
<
class
Tidx
,
class
Tval_
>
infiniStatus_t
argmax
(
void
*
workspace
,
size_t
workspace_size
,
void
*
result
,
void
const
*
probs
,
size_t
n
,
void
*
result
,
const
void
*
probs
,
size_t
n
,
void
*
stream_
)
const
{
using
Tval
=
typename
CudaTval
<
Tval_
>::
Type
;
...
...
@@ -202,7 +199,7 @@ struct Algo {
template
<
class
Tidx
,
class
Tval_
>
infiniStatus_t
random
(
void
*
workspace_
,
size_t
workspace_size
,
void
*
result_
,
void
const
*
probs
,
size_t
n
,
void
*
result_
,
const
void
*
probs
,
size_t
n
,
float
random_val
,
float
topp
,
int
topk
,
float
temperature
,
void
*
stream_
)
const
{
...
...
@@ -231,7 +228,7 @@ struct Algo {
auto
grid
=
(
n
+
block
-
1
)
/
block
;
// sort
fillIndices
<<<
grid
,
block
,
0
,
stream
>>>
(
indices
,
n
);
CHECK_CU
B
(
radixSort
(
CHECK_CU
DA
(
radixSort
(
workspace_
,
workspace_size
,
logits
,
sorted
,
indices
,
indices_out
,
...
...
@@ -241,7 +238,7 @@ struct Algo {
partialSoftmaxKernel
<<<
grid
,
block
,
0
,
stream
>>>
(
sorted
,
n
,
temperature
);
setSoftmaxMaxKernel
<<<
1
,
1
,
0
,
stream
>>>
(
sorted
);
// sum
CHECK_CU
B
(
inclusiveSum
(
CHECK_CU
DA
(
inclusiveSum
(
workspace_
,
workspace
,
sorted
,
n
,
stream
));
...
...
src/infiniop/ops/random_sample/info.h
View file @
5c0cb198
...
...
@@ -17,17 +17,12 @@ struct RandomSampleInfo {
auto
dt_i
=
result_desc
->
dtype
();
auto
dt_p
=
probs_desc
->
dtype
();
CHECK_DTYPE
(
dt_i
,
INFINI_DTYPE_U8
,
INFINI_DTYPE_U16
,
INFINI_DTYPE_U32
,
INFINI_DTYPE_U64
,
INFINI_DTYPE_I8
,
INFINI_DTYPE_I16
,
INFINI_DTYPE_I32
,
INFINI_DTYPE_I64
);
CHECK_DTYPE_ANY_INT
(
dt_i
);
CHECK_DTYPE
(
dt_p
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_F64
);
CHECK_API_OR
(
result_desc
->
ndim
(),
0
,
return
INFINI_STATUS_BAD_TENSOR_SHAPE
);
CHECK_API_OR
(
probs_desc
->
ndim
(),
1
,
return
INFINI_STATUS_BAD_TENSOR_SHAPE
);
CHECK_API_OR
(
probs_desc
->
stride
(
0
),
1
,
return
INFINI_STATUS_BAD_TENSOR_STRIDES
);
CHECK_OR_RETURN
(
result_desc
->
ndim
()
==
0
,
INFINI_STATUS_BAD_TENSOR_SHAPE
);
CHECK_OR_RETURN
(
probs_desc
->
ndim
()
==
1
,
INFINI_STATUS_BAD_TENSOR_SHAPE
);
CHECK_OR_RETURN
(
probs_desc
->
stride
(
0
)
==
1
,
INFINI_STATUS_BAD_TENSOR_STRIDES
);
return
utils
::
Result
<
RandomSampleInfo
>
({
dt_i
,
dt_p
,
probs_desc
->
dim
(
0
)});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment