Commit e1a0fb94 authored by Jing Zhang's avatar Jing Zhang
Browse files

pack half4_t

parent 3bbd5988
......@@ -9,8 +9,7 @@
namespace ck {
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
......@@ -34,7 +33,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerWave, NPerWave, KPack>{};
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, KPack>{};
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
......@@ -141,12 +140,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A
a_thread_copy_.Run(ABlockDesc{},
......@@ -164,13 +167,23 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
});
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
m0,
n0>(a_thread_buf, b_thread_buf, c_thread_buf);
n0>(a_thread_vec.template AsType<half4_t>(),
b_thread_vec.template AsType<half4_t>(),
c_thread_buf);
});
});
});
......@@ -188,8 +201,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, KPack>,
......@@ -198,8 +211,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
1, // KPack,
1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, KPack>,
......@@ -213,8 +226,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
};
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
......@@ -345,9 +357,9 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
......@@ -486,8 +498,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
......@@ -496,8 +508,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
1, // KPack,
1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
......
......@@ -307,7 +307,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
......
......@@ -759,7 +759,7 @@ struct XdlopsGemm
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
......
......@@ -646,7 +646,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1
#if 0
using in_data_t = float;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment