Commit 53963d64 authored by ltqin's avatar ltqin
Browse files

Add double register buffering for the B thread data (even/odd buffers)

parent cf360b72
......@@ -86,10 +86,10 @@ int main(int argc, char* argv[])
#if NORMAL_CONFIG
ck::index_t M = 256;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t K = 64;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideA = 64;
ck::index_t StrideB = 64;
ck::index_t StrideC = 4096;
#else
ck::index_t M = 16;
......
......@@ -235,20 +235,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
// Builds a 4-D view (named N0_N1_N2_K) on top of BK0NK1BlockDesc:
//   - source dims 0 and 2 (lengths B_K0 and B_K1) are merged into output dim 3;
//   - source dim 1 is unmerged into (NRepeat, NWaves, NPerXDL), which become
//     output dims 0, 1 and 2 respectively.
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
    // Fold the two source dimensions of lengths B_K0 and B_K1 into one index.
    constexpr auto merge_k =
        make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{}));

    // Split the remaining source dimension into its repeat / wave / per-XDL parts.
    constexpr auto split_n = make_unmerge_transform(
        make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}));

    // Which source dims feed each transform, and which output dims each one produces.
    constexpr auto src_dim_ids = make_tuple(Sequence<0, 2>{}, Sequence<1>{});
    constexpr auto dst_dim_ids = make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{});

    return transform_tensor_descriptor(
        BK0NK1BlockDesc{}, make_tuple(merge_k, split_n), src_dim_ids, dst_dim_ids);
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
......@@ -278,9 +265,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(0, 0, k / KPack, 0, n0, 0, 0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, k + i))>{}];
});
using mfma_input_type =
......@@ -304,16 +290,18 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1,
I1,
Number<K0PerThread>{}, // KPerThread
I1, // NBlockId
static constexpr auto b_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // KPerThread
Number<NRepeat>{}, // repeat
I1, // waves
I1, // NPerXdlops
Number<KPack>{}));
static constexpr auto b_thread_desc_ = transform_tensor_descriptor(
b_thread_desc,
make_tuple(make_pass_through_transform(NRepeat),
make_merge_transform_v3_division_mod(make_tuple(K0PerThread, KPack))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
......
......@@ -576,11 +576,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
I1, // waves
I1, // NPerXdlops
Number<K1>{}));
auto b_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
true>
b_thread_even_buf, b_thread_odd_buf;
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
......@@ -662,7 +663,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf);
b_thread_even_buf);
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
......@@ -684,34 +685,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// read b after gemm
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf);
b_thread_odd_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// move windows
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf);
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
++i;
i += 2;
} while(i < (K0BlockMainLoop - 1));
}
// tail
{
block_sync_lds();
// block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
}
}
#else
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment