Commit eeca63a7 authored by Tri Dao
Browse files

Bug fix: wrong smem_o write pointer for d=16

parent 765741c1
...@@ -101,7 +101,7 @@ void set_params(Fused_multihead_attention_fprop_params &params, ...@@ -101,7 +101,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// Set this to probability of keeping an element to simplify things. // Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout; params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare. // Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead < // [Minor] We want to round down since when we do the comparison we use <= instead of <
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout; params.rp_dropout = 1.f / params.p_dropout;
...@@ -111,7 +111,7 @@ void set_params(Fused_multihead_attention_fprop_params &params, ...@@ -111,7 +111,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.is_causal = is_causal; params.is_causal = is_causal;
} }
std::vector<at::Tensor> std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1 const at::Tensor &cu_seqlens, // b+1
const float p_dropout, const float p_dropout,
......
...@@ -1204,6 +1204,8 @@ struct Smem_tile_o { ...@@ -1204,6 +1204,8 @@ struct Smem_tile_o {
this->smem_write_ ^= 7 * 32; this->smem_write_ ^= 7 * 32;
} else if( Mma_tile::MMAS_N >= 2 ) { } else if( Mma_tile::MMAS_N >= 2 ) {
this->smem_write_ ^= 3 * 32; this->smem_write_ ^= 3 * 32;
} else {
this->smem_write_ ^= 3 * 32;
} }
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment