Commit 1cadb2a1 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/342: delete to

parent 0ecbe1d5
...@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) { ...@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
__builtin_memcpy(v, p, len * sizeof(T)); __builtin_memcpy(v, p, len * sizeof(T));
} }
/**
* @brief Convert data type. All data is in local memory
* @param v: input value
* @return output value
*/
template <typename Tout, typename Tin>
__device__ inline Tout to(Tin v) {
if constexpr (std::is_same<Tin, half>::value) {
return __half2float(v);
} else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
return __bfloat162float(v);
} else {
return static_cast<Tout>(v);
}
}
/** /**
* @brief atomicAdd for kunlun xpu * @brief atomicAdd for kunlun xpu
* @param ptr: pointer to shared memory * @param ptr: pointer to shared memory
......
...@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, ...@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
__shared__ Tcompute sum_; __shared__ Tcompute sum_;
if (core_id() == 0) { if (core_id() == 0) {
sum_ = to<Tcompute>(0.f); sum_ = Tcompute(0.f);
} }
sync_cluster(); sync_cluster();
...@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, ...@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
for (int index = core_id(); index < read_len; index += BLOCK_SIZE) { for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) { if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature)); y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - float(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature)); y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - float(max_value)) / temperature));
} else if constexpr (std::is_same_v<Tval, float>) { } else if constexpr (std::is_same_v<Tval, float>) {
y_sm[index] = exp((x_sm[index] - max_value) / temperature); y_sm[index] = exp((x_sm[index] - max_value) / temperature);
} }
...@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result, ...@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) { for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) { if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum); cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum); cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, half>) { } else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum); cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
} }
if (cumsum >= topp) { if (cumsum >= topp) {
end = r * buf_size + index + 1; end = r * buf_size + index + 1;
...@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result, ...@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) { for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) { if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum); cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum); cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
} else if constexpr (std::is_same_v<Tval, half>) { } else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum); cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
} }
if (random_val < cumsum) { if (random_val < cumsum) {
result[0] = indices_global[r * buf_size + index]; result[0] = indices_global[r * buf_size + index];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment