Commit 0ecbe1d5 authored by xgqdut2016
Browse files

issue/342: modified loadsm

parent 79b3acc3
...@@ -155,8 +155,8 @@ __device__ void findTopOneLocal( ...@@ -155,8 +155,8 @@ __device__ void findTopOneLocal(
template <typename Tval, typename Tidx> template <typename Tval, typename Tidx>
__device__ void TopkKernel(__global_ptr__ Tval *values, __device__ void TopkKernel(__global_ptr__ Tval *values,
__global_ptr__ Tidx *indices, __global_ptr__ Tidx *indices,
__global_ptr__ Tidx *indices_global, __global_ptr__ Tidx *indices_global, // 长度为cluster_num() * core_num() * topk
__global_ptr__ Tval *values_global, __global_ptr__ Tval *values_global, // 把长度为voc的values的前topk元素集中倒values_global
__local__ Tval *values_local, __local__ Tval *values_local,
__local__ Tidx *indices_local, __local__ Tidx *indices_local,
int voc, int voc,
...@@ -270,13 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, ...@@ -270,13 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
__shared__ Tcompute sum_; __shared__ Tcompute sum_;
if (core_id() == 0) { if (core_id() == 0) {
if constexpr (std::is_same_v<Tcompute, half>) { sum_ = to<Tcompute>(0.f);
sum_ = __float2half(0.0f);
} else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
sum_ = __float2bfloat16(0.0f);
} else if constexpr (std::is_same_v<Tcompute, float>) {
sum_ = 0.0f;
}
} }
sync_cluster(); sync_cluster();
...@@ -292,7 +286,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, ...@@ -292,7 +286,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
for (int index = core_id(); index < read_len; index += BLOCK_SIZE) { for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) { if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - to<half>(max_value)) / to<half>(temperature)); y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature)); y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, float>) { } else if constexpr (std::is_same_v<Tval, float>) {
...@@ -303,10 +297,8 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, ...@@ -303,10 +297,8 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
Tcompute sum_0 = op::common_kunlun::reduce_op::sum<BLOCK_SIZE, Tval, Tcompute>(y_sm, read_len); Tcompute sum_0 = op::common_kunlun::reduce_op::sum<BLOCK_SIZE, Tval, Tcompute>(y_sm, read_len);
__shared__ Tcompute sum_tmp_0;
if (core_id() == 0) { if (core_id() == 0) {
sum_tmp_0 = sum_0; sum_ = sum_ + sum_0;
sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
} }
sync_cluster(); sync_cluster();
} }
...@@ -330,7 +322,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, ...@@ -330,7 +322,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
} }
sync_cluster(); sync_cluster();
return loadsm(&all_sum); return all_sum;
} }
template <typename Tval, typename Tcompute, typename Tidx> template <typename Tval, typename Tcompute, typename Tidx>
__device__ void sample(__global_ptr__ Tidx *result, __device__ void sample(__global_ptr__ Tidx *result,
......
...@@ -46,7 +46,6 @@ void launchKernel(void *workspace, ...@@ -46,7 +46,6 @@ void launchKernel(void *workspace,
indices_global, indices_global,
values_global, values_global,
sum_global); sum_global);
xpu_wait(stream);
} }
else{ else{
...@@ -55,7 +54,6 @@ void launchKernel(void *workspace, ...@@ -55,7 +54,6 @@ void launchKernel(void *workspace,
values, values,
indices_global, indices_global,
values_global); values_global);
xpu_wait(stream);
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment