Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
77070490
Commit
77070490
authored
May 25, 2025
by
crapromer
Browse files
issue/36 - Migrate cuda ramdom sample to metax, but compile and run too slow
parent
5a4e7a73
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
379 additions
and
0 deletions
+379
-0
src/infiniop/ops/random_sample/maca/random_sample_kernel.h
src/infiniop/ops/random_sample/maca/random_sample_kernel.h
+254
-0
src/infiniop/ops/random_sample/maca/random_sample_maca.h
src/infiniop/ops/random_sample/maca/random_sample_maca.h
+8
-0
src/infiniop/ops/random_sample/maca/random_sample_maca.maca
src/infiniop/ops/random_sample/maca/random_sample_maca.maca
+102
-0
src/infiniop/ops/random_sample/operator.cc
src/infiniop/ops/random_sample/operator.cc
+15
-0
No files found.
src/infiniop/ops/random_sample/maca/random_sample_kernel.h
0 → 100644
View file @
77070490
#include "../../../devices/maca/maca_kernel_common.h"
#include "infinicore.h"
#include <hccub/device/device_radix_sort.cuh>
#include <hccub/device/device_reduce.cuh>
#include <hccub/device/device_scan.cuh>
namespace
op
::
random_sample
::
maca
{
// ↓↓↓ 重新封装 cub api,减少模板参数,方便调用
template
<
class
T
>
static
hcError_t
argMax_
(
cub
::
KeyValuePair
<
int
,
T
>
*
kv_pair
,
const
T
*
logits
,
int
n
,
void
*
workspace_ptr
,
size_t
&
workspace_len
,
hcStream_t
stream
)
{
return
cub
::
DeviceReduce
::
ArgMax
(
workspace_ptr
,
workspace_len
,
logits
,
kv_pair
,
n
,
stream
);
}
template
<
class
Tval
,
class
Tidx
>
static
hcError_t
radixSort
(
void
*
workspace_ptr
,
size_t
&
workspace_len
,
const
Tval
*
key_in
,
Tval
*
key_out
,
const
Tidx
*
val_in
,
Tidx
*
val_out
,
int
n
,
hcStream_t
stream
)
{
return
cub
::
DeviceRadixSort
::
SortPairsDescending
(
workspace_ptr
,
workspace_len
,
key_in
,
key_out
,
val_in
,
val_out
,
n
,
0
,
sizeof
(
Tval
)
*
8
,
stream
);
}
template
<
class
T
>
static
hcError_t
inclusiveSum
(
void
*
workspace_ptr
,
size_t
&
workspace_len
,
T
*
data
,
int
n
,
hcStream_t
stream
)
{
return
cub
::
DeviceScan
::
InclusiveSum
(
workspace_ptr
,
workspace_len
,
data
,
data
,
n
,
stream
);
}
// ↑↑↑ 重新封装 cub api,减少模板参数,方便调用
// ↓↓↓ 计算 workspace
// 地址对齐到 256
static
constexpr
size_t
align256
(
size_t
size
)
{
return
(
size
+
255
)
&
(
~
255
);
}
template
<
class
Tidx
,
class
Tval
>
utils
::
Result
<
size_t
>
calculateWorkspace
(
size_t
n_
)
{
const
auto
n
=
static_cast
<
int
>
(
n_
);
size_t
argmax
;
CHECK_MACA
(
argMax_
<
Tval
>
(
nullptr
,
nullptr
,
n
,
nullptr
,
argmax
,
nullptr
));
// 前 256 字节用于 kv pair
argmax
+=
256
;
// indices
size_t
size_random
=
align256
(
sizeof
(
Tidx
)
*
n
);
// sorted
size_random
+=
align256
(
sizeof
(
Tval
)
*
n
);
// indices_out
size_random
+=
align256
(
sizeof
(
Tidx
)
*
n
);
// cub device api
size_t
size_radix_sort
;
CHECK_MACA
((
radixSort
<
Tval
,
Tidx
>
(
nullptr
,
size_radix_sort
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
n
,
nullptr
)));
size_t
size_inclusive_sum
;
CHECK_MACA
(
inclusiveSum
<
Tval
>
(
nullptr
,
size_inclusive_sum
,
nullptr
,
n
,
nullptr
));
size_random
+=
cub
::
Max
()(
size_radix_sort
,
size_inclusive_sum
);
return
utils
::
Result
<
size_t
>
(
cub
::
Max
()(
argmax
,
size_random
));
}
// ↑↑↑ 计算 workspace
// ↓↓↓ 通过特化将 fp16_t 转换为 half
template
<
class
Tval
>
struct
CudaTval
{
using
Type
=
Tval
;
};
template
<
>
struct
CudaTval
<
fp16_t
>
{
using
Type
=
half
;
};
// ↑↑↑ 通过特化将 fp16_t 转换为 half
// ↓↓↓ 用于采样过程的小型 kernel
// maca toolkit 11.x 带的 cub::DeviceReduce::ArgMax 只接受 cub::KeyValuePair<int, Tval> 输出。
// 这个 kernel 用于取出序号
template
<
class
Tidx
,
class
Tval
>
static
__global__
void
castIdx
(
Tidx
*
result
,
const
cub
::
KeyValuePair
<
int
,
Tval
>
*
kv_pair
)
{
*
result
=
kv_pair
->
key
;
}
// 填充排序要求的序号数组
template
<
class
Tidx
>
static
__global__
void
fillIndices
(
Tidx
*
indices
,
int
n
)
{
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
<
n
)
{
indices
[
i
]
=
i
;
}
}
// random sample 使用的 softmax 可以简化为一个基本的线性映射
// 由于已经排序,最大值就是第一个数字
// 第一个数字需要被多个 block 读取,不能写
template
<
class
T
>
static
__global__
void
partialSoftmaxKernel
(
T
*
__restrict__
data
,
int
n
,
float
temperature
)
{
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
0
<
i
&&
i
<
n
)
{
float
max
=
__ldg
(
data
);
data
[
i
]
=
(
T
)
expf
(((
float
)
data
[
i
]
-
max
)
/
temperature
);
}
}
// 将第一个数字写成 1,即 exp(0)
template
<
class
T
>
static
__global__
void
setSoftmaxMaxKernel
(
T
*
__restrict__
data
)
{
*
data
=
1
;
}
// 直接 for 循环遍历采样
// 这个 kernel 仅用于避免将数据拷贝到 cpu
template
<
class
Tval
,
class
Tidx
>
static
__global__
void
randomSampleKernel
(
Tidx
*
__restrict__
result
,
const
Tval
*
__restrict__
sorted
,
const
Tidx
*
__restrict__
indices_out
,
size_t
n
,
float
random
,
float
topp
,
size_t
topk
)
{
topk
=
cub
::
Min
()(
topk
,
n
);
auto
p
=
(
Tval
)(
random
*
cub
::
Min
()(
topp
*
(
float
)
sorted
[
n
-
1
],
(
float
)
sorted
[
topk
-
1
]));
for
(
size_t
i
=
0
;;
++
i
)
{
if
((
sorted
[
i
])
>=
p
)
{
*
result
=
indices_out
[
i
];
return
;
}
}
}
// ↑↑↑ 用于采样过程的小型 kernel
struct
Algo
{
int
block_size
;
template
<
class
Tidx
,
class
Tval_
>
infiniStatus_t
argmax
(
void
*
workspace
,
size_t
workspace_size
,
void
*
result
,
const
void
*
probs
,
size_t
n
,
void
*
stream_
)
const
{
using
Tval
=
typename
CudaTval
<
Tval_
>::
Type
;
auto
stream
=
(
hcStream_t
)
stream_
;
auto
logits
=
(
Tval
*
)
probs
;
auto
kv_pair
=
(
cub
::
KeyValuePair
<
int
,
Tval
>
*
)
workspace
;
workspace
=
(
void
*
)((
char
*
)
workspace
+
256
);
workspace_size
-=
256
;
argMax_
(
kv_pair
,
logits
,
n
,
workspace
,
workspace_size
,
stream
);
castIdx
<<<
1
,
1
,
0
,
stream
>>>
((
Tidx
*
)
result
,
kv_pair
);
return
INFINI_STATUS_SUCCESS
;
}
template
<
class
Tidx
,
class
Tval_
>
infiniStatus_t
random
(
void
*
workspace_
,
size_t
workspace_size
,
void
*
result_
,
const
void
*
probs
,
size_t
n
,
float
random_val
,
float
topp
,
int
topk
,
float
temperature
,
void
*
stream_
)
const
{
using
Tval
=
typename
CudaTval
<
Tval_
>::
Type
;
auto
stream
=
(
hcStream_t
)
stream_
;
auto
logits
=
(
Tval
*
)
probs
;
auto
result
=
(
Tidx
*
)
result_
;
auto
workspace
=
reinterpret_cast
<
size_t
>
(
workspace_
);
auto
workspace_end
=
workspace
+
workspace_size
;
auto
indices
=
reinterpret_cast
<
Tidx
*>
(
workspace
);
workspace
+=
align256
(
sizeof
(
Tidx
)
*
n
);
auto
sorted
=
reinterpret_cast
<
Tval
*>
(
workspace
);
workspace
+=
align256
(
sizeof
(
Tval
)
*
n
);
auto
indices_out
=
reinterpret_cast
<
Tidx
*>
(
workspace
);
workspace
+=
align256
(
sizeof
(
Tidx
)
*
n
);
workspace_
=
reinterpret_cast
<
void
*>
(
workspace
);
workspace_size
=
workspace_end
-
workspace
;
auto
block
=
cub
::
Min
()((
size_t
)
block_size
,
n
);
auto
grid
=
(
n
+
block
-
1
)
/
block
;
// sort
fillIndices
<<<
grid
,
block
,
0
,
stream
>>>
(
indices
,
n
);
CHECK_MACA
(
radixSort
(
workspace_
,
workspace_size
,
logits
,
sorted
,
indices
,
indices_out
,
n
,
stream
));
// softmax
partialSoftmaxKernel
<<<
grid
,
block
,
0
,
stream
>>>
(
sorted
,
n
,
temperature
);
setSoftmaxMaxKernel
<<<
1
,
1
,
0
,
stream
>>>
(
sorted
);
// sum
CHECK_MACA
(
inclusiveSum
(
workspace_
,
workspace
,
sorted
,
n
,
stream
));
// sample
randomSampleKernel
<<<
1
,
1
,
0
,
stream
>>>
(
result
,
sorted
,
indices_out
,
n
,
random_val
,
topp
,
topk
);
return
INFINI_STATUS_SUCCESS
;
}
};
}
// namespace op::random_sample::maca
src/infiniop/ops/random_sample/maca/random_sample_maca.h
0 → 100644
View file @
77070490
#ifndef __RANDOM_SAMPLE_MACA_H__
#define __RANDOM_SAMPLE_MACA_H__
#include "../random_sample.h"
DESCRIPTOR
(
maca
)
#endif // __RANDOM_SAMPLE_MACA_H__
src/infiniop/ops/random_sample/maca/random_sample_maca.maca
0 → 100644
View file @
77070490
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h"
#include "../info.h"
#include "random_sample_kernel.h"
#include "random_sample_maca.h"
namespace op::random_sample::maca {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::maca::Handle *>(handle_);
auto result = RandomSampleInfo::create(result_desc, probs_desc);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size;
#define CASE_P(CASE, Tidx, Tval) \
case CASE: { \
auto workspace_result = calculateWorkspace<Tidx, Tval>(info.n); \
CHECK_RESULT(workspace_result); \
workspace_size = workspace_result.take(); \
} break
#define CASE_I(CASE, Tidx) \
case CASE: \
switch (info.dt_p) { \
CASE_P(INFINI_DTYPE_F16, Tidx, half); \
CASE_P(INFINI_DTYPE_F32, Tidx, float); \
CASE_P(INFINI_DTYPE_F64, Tidx, double); \
default: \
abort(); \
} \
break
switch (info.dt_i) {
CASE_I(INFINI_DTYPE_I8, int8_t);
CASE_I(INFINI_DTYPE_I16, int16_t);
CASE_I(INFINI_DTYPE_I32, int32_t);
CASE_I(INFINI_DTYPE_I64, int64_t);
CASE_I(INFINI_DTYPE_U8, uint8_t);
CASE_I(INFINI_DTYPE_U16, uint16_t);
CASE_I(INFINI_DTYPE_U32, uint32_t);
CASE_I(INFINI_DTYPE_U64, uint64_t);
default:
abort();
}
#undef CASE_I
#undef CASE_P
*desc_ptr = new Descriptor(
info,
workspace_size,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
size_t Descriptor::minWorkspaceSize() const {
return _min_workspace_size;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *result,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
void *stream) const {
if (workspace_size < _min_workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto block_size = _opaque->internal->blockSizeX();
Calculate::calculate<Algo>(
Algo{block_size}, _info, workspace, workspace_size,
result, probs,
random_val, topp, topk, temperature,
stream);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::maca
src/infiniop/ops/random_sample/operator.cc
View file @
77070490
...
@@ -8,6 +8,9 @@
...
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
#include "cuda/random_sample_cuda.cuh"
#include "cuda/random_sample_cuda.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
#include "maca/random_sample_maca.h"
#endif
__C
infiniStatus_t
infiniopCreateRandomSampleDescriptor
(
__C
infiniStatus_t
infiniopCreateRandomSampleDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -31,6 +34,9 @@ __C infiniStatus_t infiniopCreateRandomSampleDescriptor(
...
@@ -31,6 +34,9 @@ __C infiniStatus_t infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -58,6 +64,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
...
@@ -58,6 +64,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
GET
(
INFINI_DEVICE_NVIDIA
,
cuda
);
GET
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
maca
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -95,6 +104,9 @@ __C infiniStatus_t infiniopRandomSample(
...
@@ -95,6 +104,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -119,6 +131,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
...
@@ -119,6 +131,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
DELETE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
DELETE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
#endif
#endif
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment