Commit a4fdef4c authored by zhanghj2
Browse files

优化softmax计算

parent 3a477917
...@@ -55,12 +55,13 @@ __device__ __forceinline__ void warp_allreduce_(Tensor<Engine0, Layout0> &dst, T ...@@ -55,12 +55,13 @@ __device__ __forceinline__ void warp_allreduce_(Tensor<Engine0, Layout0> &dst, T
// smem_reduce(row, col) = dst(0); // smem_reduce(row, col) = dst(0);
} }
__syncthreads(); __syncthreads();
if (tidx < 16) // if (tidx < 16)
{ // {
smem_reduce(row + 64) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)), op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3))); // smem_reduce(row + 64) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)), op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3)));
} // }
__syncthreads(); // __syncthreads();
dst(0) = smem_reduce(row + 64); // dst(0) = smem_reduce(row + 64);
dst(0) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)), op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3)));
} }
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
...@@ -75,12 +76,13 @@ __device__ __forceinline__ void warp_allreduce_tp1(Tensor<Engine0, Layout0> &dst ...@@ -75,12 +76,13 @@ __device__ __forceinline__ void warp_allreduce_tp1(Tensor<Engine0, Layout0> &dst
smem_reduce[row * 2 + (warp_id / 4)] = dst[0]; smem_reduce[row * 2 + (warp_id / 4)] = dst[0];
} }
__syncthreads(); __syncthreads();
if (col == 0 && warp_id < 4) { // if (col == 0 && warp_id < 4) {
// printf("sum %d %d %d %.2f %.2f \n", row, tidx, warp_id, smem_reduce[row * 2], smem_reduce[row * 2 + warp_id / 4]); // // printf("sum %d %d %d %.2f %.2f \n", row, tidx, warp_id, smem_reduce[row * 2], smem_reduce[row * 2 + warp_id / 4]);
smem_reduce[128 + row] = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]); // smem_reduce[128 + row] = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
} // }
__syncthreads(); // __syncthreads();
dst(0) = smem_reduce(128 + row); // dst(0) = smem_reduce(128 + row);
dst(0) = op(smem_reduce[row * 2], smem_reduce[row * 2 + 1]);
} }
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void warp_allreduce_tp4(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) { __device__ __forceinline__ void warp_allreduce_tp4(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &smem_reduce, Operator &op) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment