"...protocols/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "057f8f470041364365090a93894b40296ee0bcb3"
Commit f6970ef8, authored by Luxios22 and committed by binmakeswell
Browse files

[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu code style (#954)

parent 0b86a634
#include <cooperative_groups.h>
#include <math.h> #include <math.h>
#include <cub/block/block_load.cuh> #include <cub/block/block_load.cuh>
...@@ -6,8 +7,6 @@ ...@@ -6,8 +7,6 @@
#include "block_reduce.h" #include "block_reduce.h"
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
const float EPSILON = 1e-8f; const float EPSILON = 1e-8f;
...@@ -120,7 +119,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, ...@@ -120,7 +119,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len); to_len);
} }
} // blockIdx.x } // blockIdx.x
} }
template <typename T, int block_dim, int ele_per_thread> template <typename T, int block_dim, int ele_per_thread>
...@@ -198,7 +197,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, ...@@ -198,7 +197,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len); to_len);
} }
} // blockIdx.x } // blockIdx.x
} }
/* /*
...@@ -304,8 +303,7 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { ...@@ -304,8 +303,7 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
sum += g.shfl_xor(sum, i);
#pragma unroll #pragma unroll
for (int i = 0; i < ITERATIONS; ++i) { for (int i = 0; i < ITERATIONS; ++i) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment