Commit 1cadb2a1 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/342: delete to

parent 0ecbe1d5
......@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
__builtin_memcpy(v, p, len * sizeof(T));
}
/**
* @brief Convert data type. All data is in local memory
* @param v: input value
* @return output value
*/
template <typename Tout, typename Tin>
__device__ inline Tout to(Tin v) {
if constexpr (std::is_same<Tin, half>::value) {
return __half2float(v);
} else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
return __bfloat162float(v);
} else {
return static_cast<Tout>(v);
}
}
/**
* @brief atomicAdd for kunlun xpu
* @param ptr: pointer to shared memory
......
......@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
__shared__ Tcompute sum_;
if (core_id() == 0) {
sum_ = to<Tcompute>(0.f);
sum_ = Tcompute(0.f);
}
sync_cluster();
......@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature));
y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - float(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature));
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - float(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, float>) {
y_sm[index] = exp((x_sm[index] - max_value) / temperature);
}
......@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
}
if (cumsum >= topp) {
end = r * buf_size + index + 1;
......@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
}
if (random_val < cumsum) {
result[0] = indices_global[r * buf_size + index];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment