Commit eeca63a7 authored by Tri Dao
Browse files

Bug fix: wrong smem_o write pointer for d=16

parent 765741c1
......@@ -101,7 +101,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead <
// [Minor] We want to round down since when we do the comparison we use <= instead of <
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout;
......@@ -111,7 +111,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.is_causal = is_causal;
}
std::vector<at::Tensor>
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
......
......@@ -1204,6 +1204,8 @@ struct Smem_tile_o {
this->smem_write_ ^= 7 * 32;
} else if( Mma_tile::MMAS_N >= 2 ) {
this->smem_write_ ^= 3 * 32;
} else {
this->smem_write_ ^= 3 * 32;
}
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment