Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
d741ee7d
Commit
d741ee7d
authored
Sep 02, 2025
by
xgqdut2016
Browse files
issue/342: success random_sample all
parent
c5bc6628
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
28 deletions
+35
-28
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
...nfiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
+33
-25
test/infiniop/random_sample.py
test/infiniop/random_sample.py
+2
-3
No files found.
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
View file @
d741ee7d
...
...
@@ -162,7 +162,7 @@ __device__ void findTopOne_local(
result[0] = indices_a;
}
template <unsigned int BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx>
template <unsigned int
CLUSTER_SIZE, unsigned int
BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx>
__global__ void random_sampleKernel(Tidx *result,
const Tval *probs,
float random_val,
...
...
@@ -263,7 +263,7 @@ __global__ void random_sampleKernel(Tidx *result,
findTopk(values_global, indices_global, nthreads * topk, topk);
}
}
sync_cluster();
//上面这部分是计算topk,数据分别存储在values_global,indices_global里面
__global_ptr__ Tval *values_global_ = values_global;
__shared__ Tval max_value;
...
...
@@ -290,7 +290,8 @@ __global__ void random_sampleKernel(Tidx *result,
if(cid == 0){
if constexpr (std::is_same_v<Tcompute, half>) {
sum_ = __float2half(0.0f);
} else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
}
else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
sum_ = __float2bfloat16(0.0f);
}
else if constexpr (std::is_same_v<Tcompute, float>) {
...
...
@@ -302,14 +303,15 @@ __global__ void random_sampleKernel(Tidx *result,
for (int r = 0; r < sm_repeat; r++) {
if (cid == 0) {
GM2SM
_ASYNC
(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
GM2SM(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
}
sync_cluster();
for (int index = cid; index < sm_size; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
}
else if constexpr (std::is_same_v<Tval, float>) {
...
...
@@ -332,13 +334,14 @@ __global__ void random_sampleKernel(Tidx *result,
if (sm_step) {
if (cid == 0) {
GM2SM
_ASYNC
(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
GM2SM(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
}
sync_cluster();
for (int index = cid; index < sm_step; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
}
else if constexpr (std::is_same_v<Tval, float>) {
...
...
@@ -358,17 +361,18 @@ __global__ void random_sampleKernel(Tidx *result,
__global_ptr__ Tcompute *sum_global_ = sum_global;
if (core_id() == 0) {
SM2GM
_ASYNC
(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
SM2GM(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
}
sync_cluster();
__shared__ Tcompute all_sum;
__shared__ Tcompute z_sm[CLUSTER_SIZE];
if(cid == 0){
GM2SM
_ASYNC
(sum_global_,
x
_sm, cluster_num() * sizeof(Tcompute));
GM2SM(sum_global_,
z
_sm, cluster_num() * sizeof(Tcompute));
}
sync_cluster();
Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(
x
_sm, cluster_num());
Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(
z
_sm, cluster_num());
if (cid == 0) {
all_sum = all_sum_0;
}
...
...
@@ -377,19 +381,19 @@ __global__ void random_sampleKernel(Tidx *result,
if (thread_id == 0) {
int end = topk;
float cumsum = 0.0f;
for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){
int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp(
(
to<float>(values_local[index]) - to<float>(loadsm(&max_value))
)
/ temperature) / to<float>(loadsm(&all_sum));
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature)
/ to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp(
(
to<float>(values_local[index]) - to<float>(loadsm(&max_value))
)
/ temperature) / to<float>(loadsm(&all_sum));
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature)
/ to<float>(loadsm(&all_sum));
}
if (cumsum >= topp) {
end = r * buf_size + index + 1;
...
...
@@ -405,11 +409,12 @@ __global__ void random_sampleKernel(Tidx *result,
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature)/ to<float>(loadsm(&all_sum));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp(
(
to<float>(values_local[index]) - to<float>(loadsm(&max_value))
)
/ temperature)/ to<float>(loadsm(&all_sum));
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature)
/ to<float>(loadsm(&all_sum));
}
if (random_val < cumsum) {
result[0] = indices_global[r * buf_size + index];
...
...
@@ -505,12 +510,13 @@ void random_sampleFunction(void *workspace,
Tval *values = (Tval *)workspace_value;
xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE);
Tval *values_global = values + n;
Tval *sum_global = values_global + cluster_num * core_num * topk_;
char *workspace_index = workspace_value + (n + cluster_num * core_num * topk_ + cluster_num) * sizeof(Tval);
char *workspace_sum = workspace_value + (n + cluster_num * core_num * topk_) * sizeof(Tval);
float *sum_global = (float *)workspace_sum;
char *workspace_index = workspace_sum + cluster_num * sizeof(float);
Tidx *indices = (Tidx *)workspace_index;
Tidx *indices_global = indices + n;
if (dosample){
random_sampleKernel<core_num, Tval,
Tval
, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
random_sampleKernel<
cluster_num,
core_num, Tval,
float
, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs,
random_val,
topp,
...
...
@@ -560,10 +566,12 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
auto info = result.take();
// size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32);
int cluster_num =
256
;
int cluster_num =
8
;
int core_num = 64;
size_t workspace_size = (probs_desc->numel() + cluster_num * core_num * probs_desc->numel() + cluster_num) * infiniSizeOf(probs_desc->dtype()) + (probs_desc->numel() + cluster_num * core_num * probs_desc->numel()) * infiniSizeOf(result_desc->dtype());
int n = probs_desc->numel();
int topk = 50;//必须想办法控制workspace大小,如果topk太大会导致无法申请进而结果报错
size_t workspace_size = (n + cluster_num * core_num * topk) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float);
*desc_ptr = new Descriptor(
info,
workspace_size,
...
...
test/infiniop/random_sample.py
View file @
d741ee7d
...
...
@@ -54,8 +54,7 @@ NUM_ITERATIONS = 1000
def
random_sample
(
data
,
random_val
,
topp
,
topk
,
voc
,
temperature
):
if
topp
>
0
and
topk
>
1
:
sorted_vals
,
sorted_indices
=
torch
.
sort
(
data
,
descending
=
True
)
print
(
sorted_vals
[:
topk
])
print
(
sorted_indices
[:
topk
])
scaled_vals
=
(
sorted_vals
-
sorted_vals
[
0
])
/
temperature
try
:
probs
=
torch
.
softmax
(
scaled_vals
,
dim
=
0
)
...
...
@@ -158,7 +157,7 @@ def test(
if
sync
is
not
None
:
sync
()
print
(
indices
.
actual_tensor
(),
ans
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug_all
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment