Commit 0ecbe1d5 authored by xgqdut2016
Browse files

issue/342: modified loadsm

parent 79b3acc3
......@@ -155,8 +155,8 @@ __device__ void findTopOneLocal(
template <typename Tval, typename Tidx>
__device__ void TopkKernel(__global_ptr__ Tval *values,
__global_ptr__ Tidx *indices,
__global_ptr__ Tidx *indices_global,
__global_ptr__ Tval *values_global,
__global_ptr__ Tidx *indices_global, // length is cluster_num() * core_num() * topk
__global_ptr__ Tval *values_global, // gathers the top-k elements of the length-voc values into values_global
__local__ Tval *values_local,
__local__ Tidx *indices_local,
int voc,
......@@ -270,13 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
__shared__ Tcompute sum_;
if (core_id() == 0) {
if constexpr (std::is_same_v<Tcompute, half>) {
sum_ = __float2half(0.0f);
} else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
sum_ = __float2bfloat16(0.0f);
} else if constexpr (std::is_same_v<Tcompute, float>) {
sum_ = 0.0f;
}
sum_ = to<Tcompute>(0.f);
}
sync_cluster();
......@@ -292,7 +286,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - to<half>(max_value)) / to<half>(temperature));
y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, float>) {
......@@ -303,10 +297,8 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
Tcompute sum_0 = op::common_kunlun::reduce_op::sum<BLOCK_SIZE, Tval, Tcompute>(y_sm, read_len);
__shared__ Tcompute sum_tmp_0;
if (core_id() == 0) {
sum_tmp_0 = sum_0;
sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
sum_ = sum_ + sum_0;
}
sync_cluster();
}
......@@ -330,7 +322,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
}
sync_cluster();
return loadsm(&all_sum);
return all_sum;
}
template <typename Tval, typename Tcompute, typename Tidx>
__device__ void sample(__global_ptr__ Tidx *result,
......
......@@ -46,7 +46,6 @@ void launchKernel(void *workspace,
indices_global,
values_global,
sum_global);
xpu_wait(stream);
}
else{
......@@ -55,7 +54,6 @@ void launchKernel(void *workspace,
values,
indices_global,
values_global);
xpu_wait(stream);
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment