Commit e1a0fb94 authored by Jing Zhang's avatar Jing Zhang
Browse files

pack half4_t

parent 3bbd5988
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA, typename FloatAB,
typename FloatB,
class ABlockDesc, class ABlockDesc,
class BBlockDesc, class BBlockDesc,
index_t MPerWave, index_t MPerWave,
...@@ -34,7 +33,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -34,7 +33,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerWave, NPerWave, KPack>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, KPack>{};
static constexpr index_t MWaves = M1 / MPerWave; static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave; static constexpr index_t NWaves = N1 / NPerWave;
...@@ -141,12 +140,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -141,12 +140,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize()); make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize()); make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A // read A
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
...@@ -164,13 +167,23 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -164,13 +167,23 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
});
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
});
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
m0, m0,
n0>(a_thread_buf, b_thread_buf, c_thread_buf); n0>(a_thread_vec.template AsType<half4_t>(),
b_thread_vec.template AsType<half4_t>(),
c_thread_buf);
}); });
}); });
}); });
...@@ -188,8 +201,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -188,8 +201,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatA, FloatAB,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, KPack>, Sequence<1, MRepeat, 1, KPack>,
...@@ -198,8 +211,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -198,8 +211,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
1, // KPack, 1, // KPack,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatB, FloatAB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, KPack>, Sequence<1, NRepeat, 1, KPack>,
...@@ -213,8 +226,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -213,8 +226,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
}; };
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA, typename FloatAB,
typename FloatB,
class ABlockDesc, class ABlockDesc,
class BBlockDesc, class BBlockDesc,
index_t MPerWave, index_t MPerWave,
...@@ -345,9 +357,9 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -345,9 +357,9 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize()); make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize()); make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
...@@ -486,8 +498,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -486,8 +498,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatA, FloatAB,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>, Sequence<1, 1, 1, KPack>,
...@@ -496,8 +508,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -496,8 +508,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
1, // KPack, 1, // KPack,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatB, FloatAB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>, Sequence<1, 1, 1, KPack>,
......
...@@ -307,7 +307,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -307,7 +307,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize, BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m0_m1_k1_block_desc), decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc), decltype(b_k0_n0_n1_k1_block_desc),
......
...@@ -759,7 +759,7 @@ struct XdlopsGemm ...@@ -759,7 +759,7 @@ struct XdlopsGemm
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
......
...@@ -646,7 +646,7 @@ int main(int argc, char* argv[]) ...@@ -646,7 +646,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment