Commit 20488ee7 authored by zhangyue's avatar zhangyue
Browse files

issue/418: 解决 p800 上手写算子引用 sm 上指针的报错问题

parent b3170335
...@@ -37,22 +37,6 @@ inline __device__ float lowerBitMask(int i) { ...@@ -37,22 +37,6 @@ inline __device__ float lowerBitMask(int i) {
return (1 << (i + 1)) - 1; return (1 << (i + 1)) - 1;
} }
/**
 * @brief Safely load a single scalar from shared memory.
 *
 * On this target, directly dereferencing a shared-memory pointer to a
 * sub-word type (half / bfloat16_t) can fault, so those types are copied
 * byte-wise into a local register value via __builtin_memcpy instead of a
 * plain dereference. All other types are loaded with a normal dereference.
 *
 * @param p pointer into shared memory (__shared_ptr__-qualified)
 * @return the value read from *p
 */
template <typename T>
__device__ inline T loadsm(__shared_ptr__ const T *p) {
    T v;
    // Use _v variable template to match the std::is_same_v style used
    // elsewhere in this file.
    if constexpr (std::is_same_v<T, half>
                  || std::is_same_v<T, bfloat16_t>) {
        // Byte-wise copy avoids an illegal direct sub-word access to SM.
        __builtin_memcpy(&v, p, sizeof(T));
    } else {
        v = *p;
    }
    return v;
}
// Load len data from shared memory // Load len data from shared memory
template <typename T> template <typename T>
__device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) { __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
...@@ -89,7 +73,7 @@ inline __device__ T atomicAdd(__shared_ptr__ T *ptr, T value) { ...@@ -89,7 +73,7 @@ inline __device__ T atomicAdd(__shared_ptr__ T *ptr, T value) {
template <> template <>
inline __device__ half atomicAdd<half>(__shared_ptr__ half *ptr, half value) { inline __device__ half atomicAdd<half>(__shared_ptr__ half *ptr, half value) {
ticket_lock_mix(); ticket_lock_mix();
__half old = loadsm(ptr); half old = *ptr;
float of = __half2float(old); float of = __half2float(old);
float vf = __half2float(value); float vf = __half2float(value);
float sumf = of + vf; float sumf = of + vf;
...@@ -103,7 +87,7 @@ inline __device__ half atomicAdd<half>(__shared_ptr__ half *ptr, half value) { ...@@ -103,7 +87,7 @@ inline __device__ half atomicAdd<half>(__shared_ptr__ half *ptr, half value) {
template <> template <>
inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *ptr, bfloat16_t value) { inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *ptr, bfloat16_t value) {
ticket_lock_mix(); ticket_lock_mix();
bfloat16_t old = loadsm(ptr); bfloat16_t old = *ptr;
float of = __bfloat162float(old); float of = __bfloat162float(old);
float vf = __bfloat162float(value); float vf = __bfloat162float(value);
float sumf = of + vf; float sumf = of + vf;
...@@ -122,7 +106,7 @@ inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *pt ...@@ -122,7 +106,7 @@ inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *pt
template <typename T> template <typename T>
inline __device__ T atomicMax(__shared_ptr__ T *ptr, T value) { inline __device__ T atomicMax(__shared_ptr__ T *ptr, T value) {
ticket_lock_mix(); ticket_lock_mix();
T old = loadsm(ptr); T old = *ptr;
if constexpr (std::is_same<T, bfloat16_t>::value) { if constexpr (std::is_same<T, bfloat16_t>::value) {
float of = __bfloat162float(old); float of = __bfloat162float(old);
float vf = __bfloat162float(value); float vf = __bfloat162float(value);
......
...@@ -31,7 +31,7 @@ __device__ void causalSoftmaxBlock( ...@@ -31,7 +31,7 @@ __device__ void causalSoftmaxBlock(
// height: 3 col_id-> // height: 3 col_id->
if (width + size_t(row_id) >= col + height) { if (width + size_t(row_id) >= col + height) {
if constexpr (std::is_same_v<Tdata, half>) { if constexpr (std::is_same_v<Tdata, half>) {
y[col] = hexp(loadsm(x + col) - loadsm(&max_)); y[col] = hexp(x[col] - max_);
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) { } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
y[col] = __float2bfloat16(exp(__bfloat162float(x[col]) - __bfloat162float(max_))); y[col] = __float2bfloat16(exp(__bfloat162float(x[col]) - __bfloat162float(max_)));
} else { } else {
...@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock( ...@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock(
// Apply softmax // Apply softmax
for (size_t col = core_id(); col < width; col += BLOCK_SIZE) { for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
if (sum_ != 0) { if (sum_ != 0) {
y[col] = to<Tdata>(to<Tcompute>(loadsm(y + col)) / sum_); y[col] = to<Tdata>(to<Tcompute>(y[col]) / sum_);
} else { } else {
y[col] = Tdata(0); y[col] = Tdata(0);
} }
......
...@@ -25,8 +25,8 @@ __device__ void rmsnormBlock( ...@@ -25,8 +25,8 @@ __device__ void rmsnormBlock(
// Copy contiguous x, w into local mem (load from shared memory safely) // Copy contiguous x, w into local mem (load from shared memory safely)
for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) { for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) {
Tdata xi = loadsm(x + i); Tdata xi = x[i];
Tweight wi = loadsm(w + i); Tweight wi = w[i];
y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms); y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms);
} }
sync_cluster(); sync_cluster();
......
...@@ -13,20 +13,20 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size ...@@ -13,20 +13,20 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
Tcompute ss = 0; Tcompute ss = 0;
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = loadsm(data_ptr + i); Tdata xi = data_ptr[i];
ss += to<Tcompute>(xi) * to<Tcompute>(xi); ss += to<Tcompute>(xi) * to<Tcompute>(xi);
} }
__shared__ Tcompute temp_storage; __shared__ Tcompute temp_storage;
if (core_id() == 0) { if (core_id() == 0) {
temp_storage = 0; temp_storage = to<Tcompute>(0.f);
} }
sync_cluster(); sync_cluster();
atomicAdd(&temp_storage, ss); atomicAdd(&temp_storage, ss);
sync_cluster(); sync_cluster();
return loadsm(&temp_storage); return temp_storage;
} }
// Sum(x) on contiguous data of length count // Sum(x) on contiguous data of length count
...@@ -35,43 +35,42 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun ...@@ -35,43 +35,42 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
Tcompute ss = 0; Tcompute ss = 0;
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
Tdata xi = loadsm(data_ptr + i); Tdata xi = data_ptr[i];
ss += to<Tcompute>(xi); ss += to<Tcompute>(xi);
} }
__shared__ Tcompute temp_storage; __shared__ Tcompute temp_storage;
if (core_id() == 0) { if (core_id() == 0) {
temp_storage = 0; temp_storage = to<Tcompute>(0.f);
} }
sync_cluster(); sync_cluster();
atomicAdd(&temp_storage, ss); atomicAdd(&temp_storage, ss);
sync_cluster(); sync_cluster();
return loadsm(&temp_storage); return temp_storage;
} }
// Max(x) on contiguous data of length count // Max(x) on contiguous data of length count
template <unsigned int BLOCK_SIZE, typename Tdata> template <unsigned int BLOCK_SIZE, typename Tdata>
__device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) { __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) {
Tdata max_val = loadsm(data_ptr); Tdata max_val = data_ptr[0];
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
// Tdata xi = loadsm(data_ptr + i); Tdata xi = data_ptr[i];
Tdata xi = loadsm(data_ptr + i);
max_val = fmax(max_val, to<Tdata>(xi)); max_val = fmax(max_val, to<Tdata>(xi));
} }
__shared__ Tdata temp_storage; __shared__ Tdata temp_storage;
if (core_id() == 0) { if (core_id() == 0) {
temp_storage = loadsm(data_ptr); temp_storage = data_ptr[0];
} }
sync_cluster(); sync_cluster();
atomicMax(&temp_storage, max_val); atomicMax(&temp_storage, max_val);
sync_cluster(); sync_cluster();
return loadsm(&temp_storage); return temp_storage;
} }
} // namespace op::common_kunlun::reduce_op } // namespace op::common_kunlun::reduce_op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment