Commit b1dd76f3 authored by ltqin's avatar ltqin
Browse files

regular code

parent e8a71150
......@@ -118,11 +118,11 @@ int main(int argc, char* argv[])
#else
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 32;
ck::index_t K = 8;
ck::index_t StrideA = 8;
ck::index_t StrideB = 8;
ck::index_t StrideC = 16;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideC = N;
#endif
if(argc == 4)
......
......@@ -241,16 +241,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
template <typename ABlockBuffer, typename AThreadBuffer>
__device__ void ReadAThreadData(const ABlockBuffer& a_block_buf, AThreadBuffer& a_thread_buf)
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
......@@ -258,8 +251,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
a_thread_buf(Number<m0>{}));
});
}
__host__ __device__ static auto AlloCAThreadBuff()
{
return generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
a_thread_desc_.GetElementSpaceSize(),
true>{};
},
Number<MRepeat>{});
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_thread_buf,
const BBlockBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
// auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
// a_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
static_for<0, KPerThread, KPack>{}([&](auto k) {
......@@ -267,8 +287,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
vector_type<FloatAB, KPack> b_thread_vec;
constexpr index_t k0 = k / KPack;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
a_thread_vec.template AsType<FloatAB>()(i) =
a_thread_buf[Number<m0>{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
});
......
......@@ -449,7 +449,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
},
Number<8>{});
Number<BaseMultK0>{});
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
......@@ -516,6 +516,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
NXdlPerWave,
K1>{};
auto a_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return blockwise_gemm.AlloCAThreadBuff();
},
Number<BaseMultK0/2>{});
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A
......@@ -531,38 +538,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
auto read_b_first_half_data = [&]() {
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<ii>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
});
};
auto read_b_last_half_data = [&]() {
static_for<MultiK0 / 2, MultiK0, 1>{}([&](auto ii) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<ii>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
});
};
auto read_a_lds_data = [&]() {
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) {
blockwise_gemm.ReadAThreadData(a_block_buf, a_thread_buf(Number<ii>{}));
blockwise_gemm.MoveABlockSliceWindow();
});
};
read_b_first_half_data();
// Initialize C
c_thread_buf.Clear();
// a data write to lds
......@@ -580,91 +584,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
block_sync_lds();
static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<4>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<5>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<6>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<7>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<4>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<5>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<6>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<7>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
read_a_lds_data();
read_b_last_half_data();
s_barrier();
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) {
blockwise_gemm.Run(a_thread_buf(Number<ii>{}),
b_thread_buf(Number<ii>{}),
c_thread_buf);
});
read_a_lds_data();
read_b_first_half_data();
s_barrier();
static_for<MultiK0 / 2, MultiK0, 1>{}([&](auto ii) {
blockwise_gemm.Run(a_thread_buf(Number<ii - 4>{}),
b_thread_buf(Number<ii>{}),
c_thread_buf);
});
});
block_sync_lds();
......@@ -683,94 +623,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
blockwise_gemm.ResetABlockStartWindow();
static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<4>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<5>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<6>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<7>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
read_a_lds_data();
read_b_last_half_data();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_barrier();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) {
blockwise_gemm.Run(
a_thread_buf(Number<ii>{}), b_thread_buf(Number<ii>{}), c_thread_buf);
});
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
read_a_lds_data();
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
read_b_first_half_data();
}
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<4>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<5>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<6>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_barrier();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<7>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
static_for<MultiK0 / 2, MultiK0, 1>{}([&](auto ii) {
blockwise_gemm.Run(a_thread_buf(Number<ii - 4>{}),
b_thread_buf(Number<ii>{}),
c_thread_buf);
});
});
}
}
......
......@@ -23,5 +23,12 @@ __device__ void s_nop()
" ::);
}
__device__ void s_barrier()
{
asm volatile("\
s_barrier \
" ::);
}
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment