Commit 7b18e6fd authored by wangshaojie6's avatar wangshaojie6
Browse files

attention with lower triangle mask with tile skipping

parent a614e299
......@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t StrideA = -1;
......
......@@ -266,39 +266,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, warpSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(warpSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0MNIdx(const index_t n4)
{
auto waveIdx = GetGemm0WaveIdx();
auto waveMNIdx = GetGemm0WaveMNIdx(waveIdx[I2]);
auto MIdx = waveIdx[I0] * MPerXdl + waveMNIdx[I1];
auto NIdx = waveIdx[I1] * NPerXdl + waveMNIdx[I0] * n4;
return make_tuple(NIdx, MIdx);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
......@@ -600,37 +567,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// get m/n id
//const auto wave_id = GetGemm0WaveIdx();
//const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
const auto m_n_id = GetGemm0MNIdx(n4);
#if 0
if(blockIdx.x == 0)
{
printf("tid=%d, wave mn id=[%d, %d], mn id=[%d, %d]\n",
static_cast<int>(threadIdx.x),
wave_m_n_id[I0],
wave_m_n_id[I1],
m_n_id[I0],
m_n_id[I1]);
}
if(blockIdx.x == 0 && threadIdx.x == 0)
{
printf("%d, %d, %d, %d, %d, %d, %d, %d\n",
static_cast<int>(m0),
static_cast<int>(n0),
static_cast<int>(m1),
static_cast<int>(n1),
static_cast<int>(m2),
static_cast<int>(n2),
static_cast<int>(n3),
static_cast<int>(n4));
}
#endif
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
......@@ -659,17 +595,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
#if 0
if(threadIdx.x == 0)
{
printf("bid=%d, A1ThreadSliceK0=%d, A1ThreadSliceM=%d, A1ThreadSliceK1=%d\n",
static_cast<int>(blockIdx.x),
static_cast<int>(A1ThreadSliceK0),
static_cast<int>(A1ThreadSliceM),
static_cast<int>(A1ThreadSliceK1));
}
#endif
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
......@@ -873,7 +798,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t nstartxdl = nstart + n0_i * NPerRepeat;
const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
static_for<0, n2, 1>{}([&](auto n2_i) {
const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * (warpSize / MPerXdl) * n4;
const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
static_for<0, n4, 1>{}([&](auto n4_i) {
const index_t n_global = nstartgroup + n4_i;
......@@ -897,32 +822,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
});
static_for<0, MXdlPerWave, 1>{}([&](auto i_m0){
static_for<0, NXdlPerWave, 1>{}([&](auto i_n0){
static_for<0, n2, 1>{}([&](auto i_n2){
static_for<0, n4, 1>{}([&](auto i_n4){
auto global_m_idx = m_block_data_idx_on_grid + m_n_id[I1] + i_m0 * Gemm0MWaves * MPerXdl;
auto global_n_idx = gemm1_k_block_outer_index * NPerBlock + m_n_id[I0] + i_n0 * Gemm0NWaves * NPerXdl + i_n4 + i_n2 * n4 * (warpSize / MPerXdl);
#if 0
if(blockIdx.x == 0 && i_m0 == 0 && i_n0 == 0)
{
printf("tid=%d, global_mn_idx=[%d, %d]\n",
static_cast<int>(threadIdx.x),
global_m_idx,
global_n_idx);
}
#endif
if(global_n_idx > global_m_idx)
{
acc_thread_buf(i_m0 * n0 * n2 * n4 + i_n0 * n2 * n4 + i_n2 * n4 + i_n4) = -ck::NumericLimits<float>::Infinity();
}
});
});
});
});
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// softmax
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment