Commit 6e3b47c3 authored by ltqin

init version, 4 buffer

parent 83e6a4b9
@@ -117,12 +117,12 @@ int main(int argc, char* argv[])
     ck::index_t StrideC = N;
 #else
     ck::index_t M = 16;
-    ck::index_t N = 16;
-    ck::index_t K = 32;
-    ck::index_t StrideA = 8;
-    ck::index_t StrideB = 8;
-    ck::index_t StrideC = 16;
+    ck::index_t N = 32;
+    ck::index_t K = 16;
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = K;
+    ck::index_t StrideC = N;
 #endif
 
     if(argc == 4)
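The new defaults tie each stride to the matrix's inner extent (StrideA = StrideB = K, StrideC = N), so the tensors stay densely packed however M, N, and K are chosen; the old literal strides (8, 8, 16) appear to have been smaller than the corresponding matrix extents. As a reminder of the leading-dimension convention these strides follow, a minimal standalone sketch (the helper below is illustrative, not part of the example):

    #include <cstddef>

    // Element (row, col) of a 2-D tensor stored with leading-dimension `stride`
    // lives at row * stride + col; a packed row-major M x K matrix uses stride = K.
    inline std::size_t leading_dim_offset(std::size_t row, std::size_t col, std::size_t stride)
    {
        return row * stride + col;
    }
    // e.g. with K = 16 and StrideA = K, A(1, 0) sits at flat offset 16.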
@@ -189,7 +189,7 @@ int main(int argc, char* argv[])
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        // a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+        //a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_k_n.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
     }
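In this default branch, A is filled with random integers in [-5, 5] while B is filled with ones. With B identically 1, every output element collapses to a row sum of A, a common debugging choice since the result can be checked by inspection. A standalone sketch of that identity (the helper name is invented):

    #include <numeric>
    #include <vector>

    // With B[k][n] == 1 for all k, n:
    //   C[m][n] = sum_k A[m][k] * B[k][n] = sum_k A[m][k]
    // so every column of C repeats the row sums of A.
    float expected_c_value(const std::vector<float>& a_row)
    {
        return std::accumulate(a_row.begin(), a_row.end(), 0.0f);
    }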
@@ -32,9 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     static constexpr index_t KPerBlock = K0PerBlock * KPack;
 
-    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
-
     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
 
     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -219,46 +216,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
             c_grid_desc_g_m0_n0_m1_n1_m2_n2);
     }
 
-    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
-    {
-        return transform_tensor_descriptor(
-            AK0MK1BlockDesc{},
-            make_tuple(
-                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
-                make_unmerge_transform(
-                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
-    }
-
-    __device__ void MoveABlockSliceWindow()
-    {
-        a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
-                                          make_multi_index(0, 0, 0, K0PerBlock * KPack));
-    }
-
-    __device__ void ResetABlockStartWindow()
-    {
-        a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
-    }
-
-    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
-
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
-    __device__ void Run(const ABlockBuffer& a_block_buf,
+    __device__ void Run(const ABlockBuffer& a_thread_buf,
                         const BBlockBuffer& b_thread_buf,
                         CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-            a_thread_desc_.GetElementSpaceSize());
-
         static_for<0, MRepeat, 1>{}([&](auto m0) {
             // read A
-            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                               make_tuple(m0, I0, I0, I0),
-                               a_block_buf,
-                               a_thread_desc_,
-                               make_tuple(I0, I0, I0, I0),
-                               a_thread_buf);
-
             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 // read B
@@ -268,7 +232,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
                     constexpr index_t k0 = k / KPack;
                     static_for<0, KPack, 1>{}([&](auto i) {
                         a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                            [Number<a_thread_desc_.CalculateOffset(make_tuple(k0, m0, i))>{}];
                         b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                             [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
                     });
@@ -291,7 +255,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     private:
     // A[M0, M1, M2, KPerThread]
     static constexpr auto a_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // KPerThread
+                                                       Number<MRepeat>{},     // repeat
+                                                       Number<KPack>{}));
 
     // B[N0, N1, N2, KPerThread]
     static constexpr auto b_thread_desc_ =
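a_thread_desc_ is now a packed 3-D descriptor with lengths (K0PerThread, MRepeat, KPack), so the CalculateOffset(make_tuple(k0, m0, i)) call in Run reduces to plain row-major arithmetic over those extents (note the surviving // A[M0, M1, M2, KPerThread] comment still describes the old 4-D layout). A minimal stand-in for the descriptor math, with extents invented for illustration:

    // For a packed descriptor, each dimension's stride is the product of the
    // lengths to its right, so with lengths (K0PerThread, MRepeat, KPack):
    //   offset(k0, m0, i) = (k0 * MRepeat + m0) * KPack + i
    constexpr int MRepeat = 2; // illustrative values only
    constexpr int KPack   = 8;

    constexpr int packed_offset(int k0, int m0, int i)
    {
        return (k0 * MRepeat + m0) * KPack + i;
    }

    static_assert(packed_offset(1, 0, 0) == MRepeat * KPack,
                  "stepping k0 skips one MRepeat x KPack tile");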
@@ -302,18 +268,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     // C[M, N, NumRegXdlops]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
 
-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                         FloatAB,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerThread>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
-
-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-
 };
 
 } // namespace ck
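Net effect of the header changes: the A-side block descriptor, the slice-window helpers, and the per-iteration ThreadwiseTensorSliceTransfer_v4 copy are gone, and Run now consumes an A fragment the caller has already staged per thread, symmetrically with B. A schematic of the resulting loop nest, with the MFMA call replaced by a scalar FMA and every extent and name invented for illustration:

    #include <array>

    // Schematic only: the real kernel drives an XDLOPS/MFMA intrinsic and takes
    // its extents from template parameters.
    template <int MRepeat, int NRepeat, int K0PerThread, int KPack>
    void run_blockwise_gemm(
        const std::array<float, K0PerThread * MRepeat * KPack>& a_thread_buf, // A staged in registers
        const std::array<float, K0PerThread * NRepeat * KPack>& b_thread_buf, // B staged in registers
        std::array<float, MRepeat * NRepeat>& c_thread_buf)                   // accumulators
    {
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int n0 = 0; n0 < NRepeat; ++n0)
                for(int k0 = 0; k0 < K0PerThread; ++k0)
                    for(int i = 0; i < KPack; ++i)
                    {
                        // packed offsets mirroring a_thread_desc_ / b_thread_desc_
                        const float a = a_thread_buf[(k0 * MRepeat + m0) * KPack + i];
                        const float b = b_thread_buf[(k0 * NRepeat + n0) * KPack + i];
                        c_thread_buf[m0 * NRepeat + n0] += a * b; // stands in for the MFMA
                    }
    }
    // usage: run_blockwise_gemm<2, 2, 4, 4>(a, b, c);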