Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c5bc6628
Commit
c5bc6628
authored
Sep 01, 2025
by
xgqdut2016
Browse files
issue/342: F16 success but BF16 failed
parent
a4b897d9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
614 additions
and
251 deletions
+614
-251
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc
...infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc
+0
-216
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
...nfiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
+611
-33
test/infiniop/random_sample.py
test/infiniop/random_sample.py
+3
-2
No files found.
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc
deleted
100644 → 0
View file @
a4b897d9
#include "random_sample_kunlun.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../info.h"
#include <assert.h>
// Device-side samplers (implemented in the companion .xpu translation unit).
// Each walks the cumulative distribution `destination` (length >= topk_),
// picks the token selected by random_val/topp, and writes its original id
// (looked up through topk_indices) into `result`.
void sample_I64(void *result, float *destination, int *topk_indices,
                float random_val, float topp, int topk_, XPUStream stream);

void sample_I32(void *result, float *destination, int *topk_indices,
                float random_val, float topp, int topk_, XPUStream stream);
namespace op::random_sample::kunlun {

// Device-specific state hidden behind the public Descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::kunlun::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}
// Builds a descriptor for one random-sample launch.
//
// @param handle_     opaque infiniop handle (must wrap a kunlun Handle)
// @param desc_ptr    receives the newly allocated Descriptor
// @param result_desc descriptor of the output index tensor
// @param probs_desc  descriptor of the probability/logit vector
// @return INFINI_STATUS_SUCCESS, or the error from RandomSampleInfo::create
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t result_desc,
    infiniopTensorDescriptor_t probs_desc) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto result = RandomSampleInfo::create(result_desc, probs_desc);
    CHECK_RESULT(result);
    auto info = result.take();
    // random_sample_kernel always stages three float32 buffers
    // (topk_values, probs_F32, destination) plus one int32 index buffer,
    // regardless of the probs dtype. Sizing by probs_desc->dtype()
    // under-allocated for F16 inputs (2-byte elements vs 4-byte float
    // staging), so size by F32 explicitly.
    size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_F32)
                          + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32);
    *desc_ptr = new Descriptor(
        info,
        workspace_size,
        new Opaque{handle->internal()},
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Smallest workspace (in bytes) that calculate() will accept.
size_t Descriptor::minWorkspaceSize() const { return _min_workspace_size; }
// Runs one sampling step on the Kunlun device via xdnn.
//
// Workspace layout on the sampling path (all staging is float32):
//   [0, topk_)              top-k values (descending)
//   [topk_, topk_ + n)      probs promoted/copied to float32
//   [topk_ + n, topk_ + 2n) softmax output, then in-place cumulative sum
//   after (2n + topk_) floats: top-k indices (int32)
// On the greedy path the head of the workspace is reused as int64 scratch.
//
// @param dt_p dtype of `probs` (F16 or F32 supported)
// @param dt_i dtype of `result` (I32 or I64 supported)
// @param n    vocabulary size (length of probs)
infiniStatus_t random_sample_kernel(
    void *workspace,
    size_t workspace_size,
    std::shared_ptr<device::kunlun::Handle::Internal> internal,
    infiniDtype_t dt_p,
    infiniDtype_t dt_i,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    int64_t n,
    void *stream) {
    // Clamp top-k to the vocabulary size.
    int topk_ = topk <= (int)n ? topk : (int)n;
    // Degenerate parameters collapse to greedy argmax.
    bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f;
    char *workspace_value = reinterpret_cast<char *>(workspace);
    if (dosample) {
        float *topk_values = (float *)workspace_value; // (topk_,)
        float *probs_F32 = topk_values + topk_;        // (n,)
        float *destination = probs_F32 + n;            // (n,)
        char *workspace_index = workspace_value + (2 * n + topk_) * sizeof(float);
        int *topk_indices = (int *)workspace_index;    // (topk_,)
        switch (dt_p) {
        case INFINI_DTYPE_F16:
            CHECK_STATUS(internal->useXdnn(
                (kunlunStream_t)stream,
                [&](xdnnHandle_t handle) {
                    // Promote to float32, then take the sorted top-k.
                    CHECK_KUNLUN((xdnn::cast<float16, float>(handle, (float16 *)probs, probs_F32, n)));
                    CHECK_KUNLUN((xdnn::sorted_topk<float>(handle, probs_F32, topk_values, topk_indices, 1, n, topk_, true, true)));
                    // Largest logit (first sorted value), subtracted for numerical stability.
                    float max_value = 0.0f;
                    xpu_memcpy(&max_value, topk_values, sizeof(float), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
                    CHECK_KUNLUN((xdnn::add_scalar<float>(handle, probs_F32, destination, max_value, -1.0f, n)));
                    CHECK_KUNLUN((xdnn::mul_scalar<float>(handle, destination, destination, 1.0 / temperature, n)));
                    CHECK_KUNLUN((xdnn::softmax<float>(handle, destination, destination, {n}, 0)));
                    // destination now holds the cumulative distribution consumed by sample_*.
                    CHECK_KUNLUN((xdnn::cumsum<float>(handle, destination, destination, {n}, false, false, 0)));
                    return INFINI_STATUS_SUCCESS;
                }));
            if (dt_i == INFINI_DTYPE_I64) {
                sample_I64(result, destination, topk_indices, random_val, topp, topk_,
                           reinterpret_cast<kunlunStream_t>(stream));
                return INFINI_STATUS_SUCCESS;
            } else if (dt_i == INFINI_DTYPE_I32) {
                sample_I32(result, destination, topk_indices, random_val, topp, topk_,
                           reinterpret_cast<kunlunStream_t>(stream));
                return INFINI_STATUS_SUCCESS;
            } else {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
            break;
        case INFINI_DTYPE_F32:
            CHECK_STATUS(internal->useXdnn(
                (kunlunStream_t)stream,
                [&](xdnnHandle_t handle) {
                    CHECK_KUNLUN((xdnn::sorted_topk<float>(handle, (float *)probs, topk_values, topk_indices, 1, n, topk_, true, true)));
                    float max_value = 0.0f;
                    xpu_memcpy(&max_value, topk_values, sizeof(float), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
                    CHECK_KUNLUN((xdnn::add_scalar<float>(handle, (float *)probs, probs_F32, max_value, -1.0f, n)));
                    CHECK_KUNLUN((xdnn::mul_scalar<float>(handle, probs_F32, probs_F32, 1.0 / temperature, n)));
                    CHECK_KUNLUN((xdnn::softmax<float>(handle, probs_F32, destination, {n}, 0)));
                    CHECK_KUNLUN((xdnn::cumsum<float>(handle, destination, destination, {n}, false, false, 0)));
                    return INFINI_STATUS_SUCCESS;
                }));
            if (dt_i == INFINI_DTYPE_I64) {
                sample_I64(result, destination, topk_indices, random_val, topp, topk_,
                           reinterpret_cast<kunlunStream_t>(stream));
                return INFINI_STATUS_SUCCESS;
            } else if (dt_i == INFINI_DTYPE_I32) {
                sample_I32(result, destination, topk_indices, random_val, topp, topk_,
                           reinterpret_cast<kunlunStream_t>(stream));
                return INFINI_STATUS_SUCCESS;
            } else {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        // Greedy path: argmax only. int64 scratch used when narrowing to I32.
        int64_t *output = (int64_t *)workspace_value;
        switch (dt_p) {
        case INFINI_DTYPE_F32:
            if (dt_i == INFINI_DTYPE_I64) {
                CHECK_STATUS(internal->useXdnn(
                    (kunlunStream_t)stream,
                    [&](xdnnHandle_t handle) {
                        CHECK_KUNLUN((xdnn::argmax<float>(handle, (float *)probs, (int64_t *)result, {n}, 0)));
                        return INFINI_STATUS_SUCCESS;
                    }));
                return INFINI_STATUS_SUCCESS;
            } else if (dt_i == INFINI_DTYPE_I32) {
                CHECK_STATUS(internal->useXdnn(
                    (kunlunStream_t)stream,
                    [&](xdnnHandle_t handle) {
                        // argmax emits int64; narrow to the caller's int32.
                        CHECK_KUNLUN((xdnn::argmax<float>(handle, (float *)probs, output, {n}, 0)));
                        CHECK_KUNLUN((xdnn::cast<int64_t, int>(handle, output, (int *)result, 1)));
                        return INFINI_STATUS_SUCCESS;
                    }));
                return INFINI_STATUS_SUCCESS;
            } else {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        case INFINI_DTYPE_F16:
            if (dt_i == INFINI_DTYPE_I64) {
                CHECK_STATUS(internal->useXdnn(
                    (kunlunStream_t)stream,
                    [&](xdnnHandle_t handle) {
                        CHECK_KUNLUN((xdnn::argmax<float16>(handle, (float16 *)probs, (int64_t *)result, {n}, 0)));
                        return INFINI_STATUS_SUCCESS;
                    }));
                return INFINI_STATUS_SUCCESS;
            } else if (dt_i == INFINI_DTYPE_I32) {
                CHECK_STATUS(internal->useXdnn(
                    (kunlunStream_t)stream,
                    [&](xdnnHandle_t handle) {
                        CHECK_KUNLUN((xdnn::argmax<float16>(handle, (float16 *)probs, output, {n}, 0)));
                        CHECK_KUNLUN((xdnn::cast<int64_t, int>(handle, output, (int *)result, 1)));
                        return INFINI_STATUS_SUCCESS;
                    }));
                return INFINI_STATUS_SUCCESS;
            } else {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    }
}
// Public entry point: validates the workspace, then forwards everything to
// random_sample_kernel together with the dtypes captured at create() time.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    void *stream) const {
    if (workspace_size < _min_workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    CHECK_STATUS(random_sample_kernel(
        workspace,
        workspace_size,
        _opaque->internal,
        _info.dt_p,
        _info.dt_i,
        result,
        probs,
        random_val,
        topp,
        topk,
        temperature,
        _info.n,
        stream));
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::random_sample::kunlun
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
View file @
c5bc6628
#ifndef __RANDOM_SAMPLE_KUNLUN_H__
#define __RANDOM_SAMPLE_KUNLUN_H__
#include "random_sample_kunlun.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
#include "../info.h"
#include <assert.h>
#include "xpu/kernel/xtdk_io.h"
using namespace device::kunlun::kernel;
using namespace op::common_kunlun::reduce_op;
// Exchanges two values held in local (core-private) memory.
template <typename Tval>
__device__ void swap_local(__local__ Tval &a, __local__ Tval &b) {
    __local__ Tval held = a;
    a = b;
    b = held;
}
// Selection-style top-k over `size` elements in global memory: after the
// loop, values[0..topk) hold the k largest elements in descending order and
// indices[0..topk) their matching entries. O(topk * size), single caller.
// half/bfloat16 comparisons go through float conversion.
template <typename Tval, typename Tidx>
__device__ void findTopk(
__global_ptr__ Tval *values,
__global_ptr__ Tidx *indices,
int size,
int topk) {
__local__ Tval values_a;
__local__ Tval values_b;
__local__ Tidx indices_a;
__local__ Tidx indices_b;
for (int i = 0; i < topk; ++i) {
for (int j = i + 1; j < size; ++j) {
// Stage both candidates into local memory before comparing.
GM2LM(values + i, &values_a, sizeof(Tval));
GM2LM(values + j, &values_b, sizeof(Tval));
GM2LM(indices + i, &indices_a, sizeof(Tidx));
GM2LM(indices + j, &indices_b, sizeof(Tidx));
if constexpr(std::is_same_v<Tval, float>){
if (values_a < values_b) {
swap_local(values_a, values_b);
swap_local(indices_a, indices_b);
}
}
else if constexpr(std::is_same_v<Tval, half>){
if (__half2float(values_a) < __half2float(values_b)) {
swap_local(values_a, values_b);
swap_local(indices_a, indices_b);
}
}
else if constexpr(std::is_same_v<Tval, bfloat16_t>){
if (__bfloat162float(values_a) < __bfloat162float(values_b)) {
swap_local(values_a, values_b);
swap_local(indices_a, indices_b);
}
}
// Written back every inner iteration, even when nothing was swapped.
LM2GM(&values_a, values + i, sizeof(Tval));
LM2GM(&values_b, values + j, sizeof(Tval));
LM2GM(&indices_a, indices + i, sizeof(Tidx));
LM2GM(&indices_b, indices + j, sizeof(Tidx));
}
}
}
// Same selection-style top-k as findTopk, but over arrays already resident
// in local memory: values[0..topk) end up as the k largest (descending),
// result[0..topk) as their associated indices.
template <typename Tval, typename Tidx>
__device__ void findTopk_local(
__local__ Tval *values,
__local__ Tidx *result,
int size,
int topk) {
for (int i = 0; i < topk; ++i) {
for (int j = i + 1; j < size; ++j) {
if constexpr(std::is_same_v<Tval, float>){
if (values[i] < values[j]) {
swap_local(values[i], values[j]);
swap_local(result[i], result[j]);
}
}
else if constexpr(std::is_same_v<Tval, half>){
// Compare in float; same for bfloat16 below.
if (__half2float(values[i]) < __half2float(values[j])) {
swap_local(values[i], values[j]);
swap_local(result[i], result[j]);
}
}
else if constexpr(std::is_same_v<Tval, bfloat16_t>){
if (__bfloat162float(values[i]) < __bfloat162float(values[j])) {
swap_local(values[i], values[j]);
swap_local(result[i], result[j]);
}
}
}
}
}
template<class Tidx>
__global__ void sampleKernel(Tidx *result, float *destination, int *topk_indices, float random_val,
// Scans `size` elements in global memory and leaves the maximum at
// values[0] with its index at indices[0].
template <typename Tval, typename Tidx>
__device__ void findTopOne(
__global_ptr__ Tval *values,
__global_ptr__ Tidx *indices,
int size) {
__local__ Tval values_a = (Tval)(-INFINITY);
__local__ Tval values_b;
__local__ Tidx indices_a = 0;
__local__ Tidx indices_b;
for (int j = 0; j < size; ++j) {
GM2LM(values + j, &values_b, sizeof(Tval));
GM2LM(indices + j, &indices_b, sizeof(Tidx));
if constexpr(std::is_same_v<Tval, float>){
if (values_a < values_b) {
values_a = values_b;
indices_a = indices_b;
}
}
else if constexpr(std::is_same_v<Tval, half>){
if (__half2float(values_a) < __half2float(values_b)) {
values_a = values_b;
indices_a = indices_b;
}
}
else if constexpr(std::is_same_v<Tval, bfloat16_t>){
if (__bfloat162float(values_a) < __bfloat162float(values_b)) {
values_a = values_b;
indices_a = indices_b;
}
}
// Store the running maximum at slot 0. NOTE(review): this write-back sits
// inside the loop, so it fires on every iteration — correct, but it could
// be hoisted after the loop to save global-memory traffic.
LM2GM(&values_a, values, sizeof(Tval));
LM2GM(&indices_a, indices, sizeof(Tidx));
}
}
// Local-memory variant of findTopOne: scans `size` local elements and
// leaves the maximum at values[0] with its index at result[0].
template <typename Tval, typename Tidx>
__device__ void findTopOne_local(
__local__ Tval *values,
__local__ Tidx *result,
int size) {
__local__ Tval values_a = (Tval)(-INFINITY);
__local__ Tidx indices_a = 0;
for (int j = 0; j < size; ++j) {
if constexpr(std::is_same_v<Tval, float>){
if (values_a < values[j]) {
values_a = values[j];
indices_a = result[j];
}
}
else if constexpr(std::is_same_v<Tval, half>){
// Compare in float; same for bfloat16 below.
if (__half2float(values_a) < __half2float(values[j])) {
values_a = values[j];
indices_a = result[j];
}
}
else if constexpr(std::is_same_v<Tval, bfloat16_t>){
if (__bfloat162float(values_a) < __bfloat162float(values[j])) {
values_a = values[j];
indices_a = result[j];
}
}
}
values[0] = values_a;
result[0] = indices_a;
}
// Top-p/top-k sampling kernel: each thread computes a partial top-k of the
// vocabulary, thread 0 merges them, all threads cooperate on a stable
// softmax normalizer over shared-memory tiles, and thread 0 finally walks
// the cumulative distribution to pick the sampled index.
//
// NOTE(review): this span was captured from a diff view and several lines of
// the *deleted* sampleKernel are interleaved into it (flagged inline below).
// The code is preserved byte-for-byte; reconcile against the repository file.
template <unsigned int BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx>
__global__ void random_sampleKernel(Tidx *result,
const Tval *probs,
float random_val,
float topp,
// NOTE(review): parameter list garbled — "int topk){" below looks like a
// remnant of the deleted sampleKernel fused with the new "int voc, int topk".
int topk){
int voc,
int topk,
float temperature,
Tidx *indices,
Tval *values,
Tidx *indices_global,
Tval *values_global,
Tcompute *sum_global) {
int cid = core_id();
int ncores = core_num();
// NOTE(review): doubled guard — "cid >= ncores" appears to be the deleted
// kernel's line interleaved with the new "cid >= BLOCK_SIZE" guard.
if (cid >= ncores) {
if (cid >= BLOCK_SIZE) {
return;
}
int thread_id = ncores * cluster_id() + cid;
// NOTE(review): the next lines ("thread_id == 0", the `end` loop, and the
// `destination[end] >= topp` test) reference `destination`, which is not a
// parameter of this kernel — remnants of the deleted sampleKernel. They also
// shadow `thread_id` with a second declaration.
if(thread_id == 0){
int end = 0;
for (end = 0; end < topk; end++) {
int thread_id = BLOCK_SIZE * cluster_id() + cid;
int nthreads = BLOCK_SIZE * cluster_num();
if (destination[end] >= topp) {
break;
// Each core id is assigned `step` elements of the vocabulary.
int remain = voc % nthreads;
int step_easy = (voc - remain) / nthreads;
int step_hard = step_easy + 1;
int step = (thread_id < remain ? step_hard : step_easy);
int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy);
// Initialize the identity permutation over this thread's slice.
for (int index = ind_start; index < ind_start + step; index++) {
indices[index] = index;
}
constexpr int buf_size = 128;
__local__ Tval values_local[2 * buf_size];
__local__ Tidx indices_local[2 * buf_size];
for (int i = 0; i < 2 * buf_size; i++) {
values_local[i] = (Tval)(-INFINITY);
indices_local[i] = 0;
}
int remainTask = step % buf_size;
int repeat = (step - remainTask) / buf_size;
if (topk >= step_easy) {
// Slices are smaller than topk: let thread 0 do a global top-k directly.
if (thread_id == 0) {
findTopk(values, indices, voc, topk);
}
sync_cluster();
for(int index = thread_id; index < topk; index += nthreads){
GM2LM(values + index, values_local, sizeof(Tval));
GM2LM(indices + index, indices_local, sizeof(Tidx));
LM2GM(values_local, values_global + index, sizeof(Tval));
LM2GM(indices_local, indices_global + index, sizeof(Tidx));
}
sync_cluster();
} else { // topk < step_easy
if (buf_size > step_easy) { // buf_size >= step_hard > step_easy > topk
GM2LM(values + ind_start, values_local, step * sizeof(Tval));
GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx));
findTopk_local(values_local, indices_local, step, topk);
LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval)); // the first nthreads * topk slots of values_global hold each core's top-k
LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx));
} else { // buf_size <= step_easy
if (topk > buf_size) { // step_easy > topk > buf_size
findTopk(&values[ind_start], &indices[ind_start], step, topk);
for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){
int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
GM2LM(values + ind_start + r * buf_size, values_local, read_len * sizeof(Tval));
GM2LM(indices + ind_start + r * buf_size, indices_local, read_len * sizeof(Tidx));
LM2GM(values_local, values_global + thread_id * topk + r * buf_size, read_len * sizeof(Tval));
LM2GM(indices_local, indices_global + thread_id * topk + r * buf_size, read_len * sizeof(Tidx));
}
} else { // step_easy >= buf_size >= topk
for (int r = 0; r < repeat; r++) {
GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval));
GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx));
findTopk_local(values_local, indices_local, buf_size + topk, topk); // each round also compares against the previous round's top-k
for (int i = buf_size; i < buf_size + topk; i++) { // stash this round's top-k into the upper half for the next round
values_local[i] = values_local[i - buf_size];
indices_local[i] = indices_local[i - buf_size];
}
}
if (remainTask) {
// Here repeat must be > 0, and values_local[buf_size:buf_size+topk] holds the previous round's top-k.
for(int i = 0; i < topk; i++){
values_local[i] = values_local[i + buf_size];
indices_local[i] = indices_local[i + buf_size];
}
GM2LM(values + ind_start + repeat * buf_size, values_local + topk, remainTask * sizeof(Tval));
GM2LM(indices + ind_start + repeat * buf_size, indices_local + topk, remainTask * sizeof(Tidx));
findTopk_local(values_local, indices_local, remainTask + topk, topk);
}
LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval));
LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx));
}
}
if (thread_id == 0) {
// Merge the per-thread candidate lists into the final global top-k.
findTopk(values_global, indices_global, nthreads * topk, topk);
}
}
// The section above computes the top-k; results live in values_global / indices_global.
__global_ptr__ Tval *values_global_ = values_global;
__shared__ Tval max_value;
if(core_id() == 0){
// values_global[0] is the largest value after the merge.
GM2SM(values_global, &max_value, sizeof(Tval));
}
sync_cluster();
__shared__ Tval x_sm[SM_SIZE / sizeof(Tval)];
__shared__ Tval y_sm[SM_SIZE / sizeof(Tval)];
int sm_size = SM_SIZE / sizeof(Tval);
int all_sm_size = cluster_num() * sm_size;
// Split the vocabulary into full shared-memory tiles plus one remainder tile.
int sm_remain = voc % all_sm_size;
int sm_repeat = (voc - sm_remain) / all_sm_size;
int sm_remain_cluster = sm_remain % cluster_num();
int sm_step_easy = (sm_remain - sm_remain_cluster) / cluster_num();
int sm_step_hard = sm_step_easy + 1;
int sm_step = (cluster_id() < sm_remain_cluster ? sm_step_hard : sm_step_easy);
int sm_ind_start = (cluster_id() < sm_remain_cluster ? cluster_id() * sm_step_hard : sm_remain_cluster * sm_step_hard + (cluster_id() - sm_remain_cluster) * sm_step_easy);
__shared__ Tcompute sum_;
if(cid == 0){
if constexpr (std::is_same_v<Tcompute, half>) {
sum_ = __float2half(0.0f);
} else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
sum_ = __float2bfloat16(0.0f);
}
else if constexpr (std::is_same_v<Tcompute, float>) {
sum_ = 0.0f;
}
}
sync_cluster();
__global_ptr__ Tval const *probs_ = probs;
// Accumulate the softmax denominator sum(exp((x - max)/T)) tile by tile.
for (int r = 0; r < sm_repeat; r++) {
if (cid == 0) {
GM2SM_ASYNC(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
}
sync_cluster();
for (int index = cid; index < sm_size; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
}
else if constexpr (std::is_same_v<Tval, float>) {
y_sm[index] = exp((x_sm[index] - max_value) / temperature);
}
}
sync_cluster();
Tcompute sum_0 = sum<BLOCK_SIZE, Tval, Tcompute>(y_sm, sm_size);
__shared__ Tcompute sum_tmp_0;
if (cid == 0) {
sum_tmp_0 = sum_0;
sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
}
sync_cluster();
}
// Remainder tile (when voc is not a multiple of all_sm_size).
if (sm_step) {
if (cid == 0) {
GM2SM_ASYNC(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
}
sync_cluster();
for (int index = cid; index < sm_step; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
}
else if constexpr (std::is_same_v<Tval, float>) {
y_sm[index] = exp((x_sm[index] - max_value) / temperature);
}
}
sync_cluster();
Tcompute sum_0 = sum<BLOCK_SIZE, Tval, Tcompute>(y_sm, sm_step);
__shared__ Tcompute sum_tmp_0;
if (cid == 0) {
sum_tmp_0 = sum_0;
sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
}
sync_cluster();
}
// Publish per-cluster partial sums, then reduce them into all_sum.
__global_ptr__ Tcompute *sum_global_ = sum_global;
if (core_id() == 0) {
SM2GM_ASYNC(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
}
sync_cluster();
__shared__ Tcompute all_sum;
if(cid == 0){
GM2SM_ASYNC(sum_global_, x_sm, cluster_num() * sizeof(Tcompute));
}
sync_cluster();
// NOTE(review): the "end < topk - 1 / end += 1 / end = topk" lines below use
// the deleted kernel's `end` loop variable — interleaved diff remnants.
if (end < topk - 1) {
end += 1;
} else {
end = topk;
Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(x_sm, cluster_num());
if (cid == 0) {
all_sum = all_sum_0;
}
sync_cluster();
// Thread 0 draws the sample: truncate to the top-p prefix, renormalize,
// then pick the first index whose cumulative mass exceeds random_val.
if (thread_id == 0) {
int end = topk;
float cumsum = 0.0f;
// NOTE(review): `destination` is not declared in this kernel — another
// interleaved remnant of the deleted sampleKernel.
random_val *= destination[end - 1];
for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){
int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
// NOTE(review): the three lines below (loop over `destination` /
// `topk_indices`) belong to the deleted sampleKernel — diff remnants
// spliced into this constexpr chain.
for (int i = 0; i < end; i++) {
if (random_val < destination[i]) {
result[0] = static_cast<Tidx>(topk_indices[i]);
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
}
if (cumsum >= topp) {
end = r * buf_size + index + 1;
break;
}
}
}
// Renormalize the draw to the truncated (top-p) mass and resample.
random_val *= cumsum;
cumsum = 0.0f;
for(int r = 0; r < end / buf_size + (end % buf_size > 0 ? 1 : 0); r++){
int read_len = (r < end / buf_size ? buf_size : end % buf_size);
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature)/ to<float>(loadsm(&all_sum));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature)/ to<float>(loadsm(&all_sum));
}
if (random_val < cumsum) {
result[0] = indices_global[r * buf_size + index];
break;
}
}
}
}
}
void sample_I64(void *result, float *destination, int *topk_indices, float random_val,
// Greedy path: parallel argmax over `values` (a device copy of probs).
// Each thread reduces its slice to one (value, index) pair in
// values_global/indices_global; thread 0 then reduces across threads and
// writes the winning index to result[0]. NOTE(review): `probs` itself is not
// read here — the launcher pre-copies it into `values`; confirm intent.
template <typename Tval, typename Tidx>
__global__ void argmaxKernel(Tidx *result, const Tval *probs, int voc,
Tidx *indices,
Tval *values,
Tidx *indices_global,
Tval *values_global){
int cid = core_id();
if (cid >= core_num()) {
return;
}
int thread_id = core_num() * cluster_id() + cid;
int nthreads = core_num() * cluster_num();
// Each core id is assigned `step` elements of the vocabulary.
int remain = voc % nthreads;
int step_easy = (voc - remain) / nthreads;
int step_hard = step_easy + 1;
int step = (thread_id < remain ? step_hard : step_easy);
int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy);
// Initialize the identity permutation over this thread's slice.
for (int index = ind_start; index < ind_start + step; index++) {
indices[index] = index;
}
constexpr int buf_size = 128;
__local__ Tval values_local[2 * buf_size];
__local__ Tidx indices_local[2 * buf_size];
for (int i = 0; i < 2 * buf_size; i++) {
values_local[i] = (Tval)(-INFINITY);
indices_local[i] = 0;
}
int remainTask = step % buf_size;
int repeat = (step - remainTask) / buf_size;
if (buf_size > step_easy) { // buf_size >= step_hard > step_easy
// Whole slice fits in one local buffer.
GM2LM(values + ind_start, values_local, step * sizeof(Tval));
GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx));
findTopOne_local(values_local, indices_local, step);
LM2GM(values_local, values_global + thread_id, sizeof(Tval));
LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx));
} else { // buf_size <= step_easy
// Stream the slice in buf_size chunks, carrying the running max in slot buf_size.
for (int r = 0; r < repeat; r++) {
GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval));
GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx));
findTopOne_local(values_local, indices_local, buf_size + 1);
values_local[buf_size] = values_local[0];
indices_local[buf_size] = indices_local[0];
}
if (remainTask) {
GM2LM(values + ind_start + repeat * buf_size, values_local, remainTask * sizeof(Tval));
GM2LM(indices + ind_start + repeat * buf_size, indices_local, remainTask * sizeof(Tidx));
// Here repeat must be > 0 (the carried max sits at slot buf_size).
values_local[remainTask] = values_local[buf_size];
indices_local[remainTask] = indices_local[buf_size];
findTopOne_local(values_local, indices_local, remainTask + 1);
}
LM2GM(values_local, values_global + thread_id, sizeof(Tval));
LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx));
}
if (thread_id == 0) {
// Final cross-thread reduction, then emit the winning index.
findTopOne(values_global, indices_global, nthreads);
result[0] = indices_global[0];
}
}
template <typename Tval, typename Tidx>
void random_sampleFunction(void *workspace,
void *result,
const void *probs,
float random_val,
float topp,
int topk_, XPUStream stream){
sampleKernel<int64_t><<<1, 1, stream>>>((int64_t *)result, destination, topk_indices, random_val, topp, topk_);
int topk,
float temperature,
int64_t n,
XPUStream stream) {
constexpr unsigned int cluster_num = 8;
constexpr unsigned int core_num = 64;
char *workspace_value = reinterpret_cast<char *>(workspace);
int topk_ = topk <= (int)n ? topk : (int)n;
bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f;
Tval *values = (Tval *)workspace_value;
xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE);
Tval *values_global = values + n;
Tval *sum_global = values_global + cluster_num * core_num * topk_;
char *workspace_index = workspace_value + (n + cluster_num * core_num * topk_ + cluster_num) * sizeof(Tval);
Tidx *indices = (Tidx *)workspace_index;
Tidx *indices_global = indices + n;
if (dosample){
random_sampleKernel<core_num, Tval, Tval, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs,
random_val,
topp,
n,
topk_,
temperature,
indices,
values,
indices_global,
values_global,
sum_global);
xpu_wait(stream);
}
else{
argmaxKernel<Tval, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result, (Tval *)probs, n,
indices,
values,
indices_global,
values_global);
xpu_wait(stream);
}
}
void sample_I32(void *result, float *destination, int *topk_indices, float random_val,
// Expands inside Descriptor::calculate: instantiates random_sampleFunction
// for the given (value, index) type pair using the surrounding locals
// (workspace, result, probs, random_val, topp, topk, temperature, n, stream).
#define LAUNCH_KERNEL(Tval, Tidx) \
random_sampleFunction<Tval, Tidx>(workspace, result, probs, random_val, topp, topk, temperature, n, reinterpret_cast<kunlunStream_t>(stream));
namespace op::random_sample::kunlun {
// Device-specific state hidden behind the public Descriptor.
struct Descriptor::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
// Releases the opaque device state owned by this descriptor.
Descriptor::~Descriptor() {
delete _opaque;
}
// Builds a descriptor and computes the device workspace requirement for
// random_sampleFunction's buffer layout (values, values_global, sum_global,
// indices, indices_global).
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto result = RandomSampleInfo::create(result_desc, probs_desc);
CHECK_RESULT(result);
auto info = result.take();
// size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32);
// NOTE(review): cluster_num here is 256 while random_sampleFunction launches
// with cluster_num = 8 — and numel() stands in for topk (unknown at create
// time), so this massively over-provisions. Presumably a deliberate upper
// bound; confirm, since the allocation is cluster_num * core_num * numel
// elements of each dtype.
int cluster_num = 256;
int core_num = 64;
size_t workspace_size = (probs_desc->numel() + cluster_num * core_num * probs_desc->numel() + cluster_num) * infiniSizeOf(probs_desc->dtype()) + (probs_desc->numel() + cluster_num * core_num * probs_desc->numel()) * infiniSizeOf(result_desc->dtype());
*desc_ptr = new Descriptor(
info,
workspace_size,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Smallest workspace (in bytes) that calculate() will accept.
size_t Descriptor::minWorkspaceSize() const { return _min_workspace_size; }
// Host entry point: validates the workspace, then dispatches the templated
// launcher on (index dtype, probs dtype).
//
// NOTE(review): the captured span interleaved two lines of the deleted
// sample_I32 wrapper ("int topk_, XPUStream stream){ sampleKernel<...>")
// into this signature — extraction artifacts removed here; they referenced
// `destination`/`topk_indices`, which do not exist in this scope.
infiniStatus_t
Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    void *stream) const {
    if (workspace_size < _min_workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    int n = (int)_info.n; // vocabulary size, consumed by LAUNCH_KERNEL
    if (_info.dt_i == INFINI_DTYPE_I32){
        switch (_info.dt_p) {
        case INFINI_DTYPE_F16:
            LAUNCH_KERNEL(half, int32_t);
            return INFINI_STATUS_SUCCESS;
        case INFINI_DTYPE_BF16:
            LAUNCH_KERNEL(bfloat16_t, int32_t);
            return INFINI_STATUS_SUCCESS;
        case INFINI_DTYPE_F32:
            LAUNCH_KERNEL(float, int32_t);
            return INFINI_STATUS_SUCCESS;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    }
    else if (_info.dt_i == INFINI_DTYPE_I64){
        switch (_info.dt_p) {
        case INFINI_DTYPE_F16:
            LAUNCH_KERNEL(half, int64_t);
            return INFINI_STATUS_SUCCESS;
        case INFINI_DTYPE_BF16:
            LAUNCH_KERNEL(bfloat16_t, int64_t);
            return INFINI_STATUS_SUCCESS;
        case INFINI_DTYPE_F32:
            LAUNCH_KERNEL(float, int64_t);
            return INFINI_STATUS_SUCCESS;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    }
    else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS; // unreachable; kept to satisfy the compiler
}
#endif // __RANDOM_SAMPLE_KUNLUN_H__
} // namespace op::random_sample::kunlun
test/infiniop/random_sample.py
View file @
c5bc6628
...
...
@@ -54,7 +54,8 @@ NUM_ITERATIONS = 1000
def
random_sample
(
data
,
random_val
,
topp
,
topk
,
voc
,
temperature
):
if
topp
>
0
and
topk
>
1
:
sorted_vals
,
sorted_indices
=
torch
.
sort
(
data
,
descending
=
True
)
print
(
sorted_vals
[:
topk
])
print
(
sorted_indices
[:
topk
])
scaled_vals
=
(
sorted_vals
-
sorted_vals
[
0
])
/
temperature
try
:
probs
=
torch
.
softmax
(
scaled_vals
,
dim
=
0
)
...
...
@@ -157,7 +158,7 @@ def test(
if
sync
is
not
None
:
sync
()
print
(
indices
.
actual_tensor
(),
ans
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug_all
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment