Commit 506c8eb3 authored by wangshaojie6
Browse files

init code for tile skipping

parent 3f9100cc
...@@ -164,8 +164,8 @@ int main(int argc, char* argv[]) ...@@ -164,8 +164,8 @@ int main(int argc, char* argv[])
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7; ck::index_t G0 = 2;
ck::index_t G1 = 13; ck::index_t G1 = 3;
if(argc == 1) if(argc == 1)
{ {
......
...@@ -591,6 +591,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -591,6 +591,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
A1ThreadSlice_K0_M_K1, A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
#if 0
if(threadIdx.x == 0)
{
printf("bid=%d, A1ThreadSliceK0=%d, A1ThreadSliceM=%d, A1ThreadSliceK1=%d\n",
static_cast<int>(blockIdx.x),
static_cast<int>(A1ThreadSliceK0),
static_cast<int>(A1ThreadSliceM),
static_cast<int>(A1ThreadSliceK1));
}
#endif
// B1 matrix in LDS memory, dst of blockwise copy // B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
...@@ -753,6 +764,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -753,6 +764,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
do do
{ {
if((m_block_data_idx_on_grid < gemm1_k_block_outer_index * NPerBlock) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm1_k_block_outer_index * NPerBlock + NPerBlock - 1)))
{
continue;
}
// gemm0 // gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
......
...@@ -195,7 +195,7 @@ struct ReferenceBatchedGemmUpperTriangleMinusInf : public device::BaseOperator ...@@ -195,7 +195,7 @@ struct ReferenceBatchedGemmUpperTriangleMinusInf : public device::BaseOperator
AccDataType v_c; AccDataType v_c;
if(n <= m) if(((n >> 7) << 7) <= ((m >> 7) << 7))
{ {
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment