"...composable_kernel_rocm.git" did not exist on "0023f01ab02b9cc05a98ae1a7753df1481252e4d"
Commit 870a2482 authored by wangshaojie6's avatar wangshaojie6
Browse files

functionality right with lower triangle mask

parent 506c8eb3
...@@ -97,6 +97,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -97,6 +97,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{}; static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{}; static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1 // Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{}; static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
...@@ -262,6 +266,39 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -262,6 +266,39 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, warpSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(warpSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0MNIdx(const index_t n4)
{
auto waveIdx = GetGemm0WaveIdx();
auto waveMNIdx = GetGemm0WaveMNIdx(waveIdx[I2]);
auto MIdx = waveIdx[I0] * MPerXdl + waveMNIdx[I1];
auto NIdx = waveIdx[I1] * NPerXdl + waveMNIdx[I0] * n4;
return make_tuple(NIdx, MIdx);
}
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -563,6 +600,37 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -563,6 +600,37 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// get m/n id
//const auto wave_id = GetGemm0WaveIdx();
//const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
const auto m_n_id = GetGemm0MNIdx(n4);
#if 0
if(blockIdx.x == 0)
{
printf("tid=%d, wave mn id=[%d, %d], mn id=[%d, %d]\n",
static_cast<int>(threadIdx.x),
wave_m_n_id[I0],
wave_m_n_id[I1],
m_n_id[I0],
m_n_id[I1]);
}
if(blockIdx.x == 0 && threadIdx.x == 0)
{
printf("%d, %d, %d, %d, %d, %d, %d, %d\n",
static_cast<int>(m0),
static_cast<int>(n0),
static_cast<int>(m1),
static_cast<int>(n1),
static_cast<int>(m2),
static_cast<int>(n2),
static_cast<int>(n3),
static_cast<int>(n4));
}
#endif
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0 // n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m // m0_m1_m2 -> m
...@@ -796,6 +864,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -796,6 +864,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}); });
#endif #endif
static_for<0, MXdlPerWave, 1>{}([&](auto i_m0){
static_for<0, NXdlPerWave, 1>{}([&](auto i_n0){
static_for<0, n2, 1>{}([&](auto i_n2){
static_for<0, n4, 1>{}([&](auto i_n4){
auto global_m_idx = m_block_data_idx_on_grid + m_n_id[I1] + i_m0 * Gemm0MWaves * MPerXdl;
auto global_n_idx = gemm1_k_block_outer_index * NPerBlock + m_n_id[I0] + i_n0 * Gemm0NWaves * NPerXdl + i_n4 + i_n2 * n4 * (warpSize / MPerXdl);
#if 0
if(blockIdx.x == 0 && i_m0 == 0 && i_n0 == 0)
{
printf("tid=%d, global_mn_idx=[%d, %d]\n",
static_cast<int>(threadIdx.x),
global_m_idx,
global_n_idx);
}
#endif
if(global_n_idx > global_m_idx)
{
acc_thread_buf(i_m0 * n0 * n2 * n4 + i_n0 * n2 * n4 + i_n2 * n4 + i_n4) = -ck::NumericLimits<float>::Infinity();
}
});
});
});
});
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// softmax // softmax
......
...@@ -195,7 +195,7 @@ struct ReferenceBatchedGemmUpperTriangleMinusInf : public device::BaseOperator ...@@ -195,7 +195,7 @@ struct ReferenceBatchedGemmUpperTriangleMinusInf : public device::BaseOperator
AccDataType v_c; AccDataType v_c;
if(((n >> 7) << 7) <= ((m >> 7) << 7)) if(((n >> 0) << 0) <= ((m >> 0) << 0))
{ {
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment