"src/lib/vscode:/vscode.git/clone" did not exist on "3ae1a5b14d81564f348525f27dac4e821ba6dbdc"
Commit 45567a25 authored by Kirthi Shankar Sivamani
Browse files

Only one thread writes the RNG state to global memory in fprop


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a0997bc7
...@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) { ...@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
const int bidb = blockIdx.x; const int bidb = blockIdx.x;
// The block index for the head. // The block index for the head.
const int bidh = blockIdx.y; const int bidh = blockIdx.y;
// The block index.
const int bidx = gridDim.x * bidh + bidb;
// The thread index. // The thread index.
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
...@@ -678,8 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) { ...@@ -678,8 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of // the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern. // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
auto seeds = at::cuda::philox::unpack(params.philox_args); auto seeds = at::cuda::philox::unpack(params.philox_args);
params.rng_state[0] = std::get<0>(seeds); if (bidx == 0 && tidx == 0) {
params.rng_state[1] = std::get<1>(seeds); params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
}
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32); Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
constexpr int M = Kernel_traits::Cta_tile_p::M; constexpr int M = Kernel_traits::Cta_tile_p::M;
const int STEPS = (params.seqlen_q + M - 1) / M; const int STEPS = (params.seqlen_q + M - 1) / M;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment