"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "ebc3edc6da23c99806882a3551c0758869484256"
Commit d3341a67 authored by Jing Zhang

xdlops refactor

parent b62bf8c3
@@ -9,82 +9,99 @@ namespace ck {

 template <index_t BlockSize,
           typename FloatAB,
-          class ABlockDesc,
-          class BBlockDesc,
-          index_t MPerWave,
-          index_t NPerWave,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
           index_t K1>
-struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
+struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
+    using CIndex = MultiIndex<2>;
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    using CIndex = MultiIndex<2>;
-
     static constexpr index_t WaveSize = 64;

-    static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
-    static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
-
-    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
-    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
-
-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
-
-    static constexpr index_t MWaves = M1 / MPerWave;
-    static constexpr index_t NWaves = N1 / NPerWave;
-
-    static constexpr index_t MRepeat = M0;
-    static constexpr index_t NRepeat = N0;
-
-    __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
-
-    __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
-
-    __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
+    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
+    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
+
+    static constexpr index_t K0        = BK0NK1BlockDesc{}.GetLength(I0);
+    static constexpr index_t KPerBlock = K0;
+    static constexpr index_t KPack     = K1;
+
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
+
+    static constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
+
+    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
+
+    __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDesc()
+    {
+        constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
+        constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
+
+        return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
+    }
+
+    __device__ static auto GetWaveIdx()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+
+        const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }

     __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t waveId    = thread_id / WaveSize;
-        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+        const auto laneId   = wave_idx[I2];
+
+        const auto blk_idx = xdlops_gemm.GetBlkIdx();
+        const auto blk_id  = blk_idx[I0];
+        const auto blk_td  = blk_idx[I1];

         if constexpr(xdlops_gemm.IsKReduction)
         {
-            const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
-            const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
-            return make_tuple(k_offset, 0, m_offset, 0);
+            return make_tuple(blk_id, 0, waveId_m, blk_td, 0);
         }
         else
         {
-            const index_t m_offset = waveId_m * MPerWave + laneId;
-            const index_t k_offset = 0;
-            return make_tuple(k_offset, 0, m_offset, 0);
+            return make_tuple(0, 0, waveId_m, laneId, 0);
         }
     }

     __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t waveId    = thread_id / WaveSize;
-        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_n  = waveId % NWaves;
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_n = wave_idx[I1];
+        const auto laneId   = wave_idx[I2];
+
+        const auto blk_idx = xdlops_gemm.GetBlkIdx();
+        const auto blk_id  = blk_idx[I0];
+        const auto blk_td  = blk_idx[I1];

         if constexpr(xdlops_gemm.IsKReduction)
         {
-            const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
-            const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
-            return make_tuple(k_offset, 0, n_offset, 0);
+            return make_tuple(blk_id, 0, waveId_n, blk_td, 0);
         }
         else
         {
-            const index_t n_offset = waveId_n * NPerWave + laneId;
-            const index_t k_offset = 0;
-            return make_tuple(k_offset, 0, n_offset, 0);
+            return make_tuple(0, 0, waveId_n, laneId, 0);
         }
     }
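The merge transform used in GetWaveIdx() above amounts to a row-major decomposition of the flat thread id into (wave_m, wave_n, lane). A minimal host-side sketch of that arithmetic, with MWaves, NWaves and the example thread ids assumed for illustration (they are not taken from this diff):

#include <cstdio>

int main()
{
    // Assumed example sizes: 2x2 waves of 64 lanes = 256 threads per block.
    const int MWaves = 2, NWaves = 2, WaveSize = 64;

    for(int tid : {0, 63, 64, 191, 255})
    {
        // Same decomposition the merge(MWaves, NWaves, WaveSize) adaptor performs.
        const int lane   = tid % WaveSize;
        const int wave_n = (tid / WaveSize) % NWaves;
        const int wave_m = tid / (WaveSize * NWaves);
        std::printf("tid %3d -> wave_m %d, wave_n %d, lane %2d\n", tid, wave_m, wave_n, lane);
    }
    return 0;
}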
@@ -92,263 +109,117 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
     __device__ static CIndex
     CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
     {
-        const index_t waveId = get_thread_local_1d_id() / WaveSize;
-
-        const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
-
-        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
-
-        const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
-        const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
-
-        return CIndex{m_offset, n_offset};
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
+
+        constexpr auto mrepeat_mwave_mperxdl_to_m = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}));
+
+        constexpr auto nrepeat_nwave_nperxdl_to_n = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}));
+
+        const index_t c_thread_m =
+            mrepeat_mwave_mperxdl_to_m.CalculateOffset(make_tuple(m0, waveId_m, blk_idx[I0]));
+        const index_t c_thread_n =
+            nrepeat_nwave_nperxdl_to_n.CalculateOffset(make_tuple(n0, waveId_n, blk_idx[I1]));
+
+        return CIndex{c_thread_m, c_thread_n};
     }

-    __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
-        : a_thread_copy_{CalculateAThreadOriginDataIndex()},
-          b_thread_copy_{CalculateBThreadOriginDataIndex()}
+    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
     {
-        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
+        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
+                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
-                      "wrong! K dimension not consistent");
+        static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
+                      "wrong! K0 dimension not consistent");

-        static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
+        static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
                       "wrong! K1 dimension not consistent");

         static_assert(BlockSize == MWaves * NWaves * WaveSize,
                       "BlockSize != MWaves * NWaves * WaveSize\n");

-        static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
-
-        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
-
         static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
         static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
-    }
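In the refactored CalculateCThreadOriginDataIndex, the packed (MRepeat, MWaves, MPerXDL) descriptor turns the triple (m0, waveId_m, blk_idx[I0]) into a single linear M coordinate; the N side is analogous. A small standalone sketch of that offset computation, with the tile sizes assumed for illustration (not taken from this diff):

#include <cstdio>

// CalculateOffset on a packed 3-d descriptor is plain row-major linearization.
int packed_offset_3d(int i0, int i1, int i2, int len1, int len2)
{
    return (i0 * len1 + i1) * len2 + i2;
}

int main()
{
    const int MWaves = 2, MPerXDL = 32; // assumed example sizes

    const int m0 = 1, waveId_m = 1, blk_m = 5; // blk_m plays the role of blk_idx[I0]
    const int c_thread_m = packed_offset_3d(m0, waveId_m, blk_m, MWaves, MPerXDL);
    std::printf("c_thread_m = %d\n", c_thread_m); // (1*2 + 1)*32 + 5 = 101
    return 0;
}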
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
b_thread_copy_.Run(BBlockDesc{}, "wrong!");
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
using mfma_input_type = constexpr index_t NumBlks = CXdlopsLayout.GetNumBlks();
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type; constexpr index_t NumXdlops = CXdlopsLayout.GetNumXdlops();
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
});
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
m0,
n0>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
});
} }
private: __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
// A[K, M]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
template <index_t BlockSize,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
{ {
const index_t thread_id = get_thread_local_1d_id(); constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
const index_t waveId = thread_id / WaveSize; constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
const index_t laneId = thread_id % WaveSize; constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
const index_t waveId_m = waveId / NWaves; constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
if constexpr(xdlops_gemm.IsKReduction) return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
{ Number<NRepeat>{},
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); Number<MWaves>{},
const index_t k_offset = xdlops_gemm.GetBlkId(laneId); Number<NWaves>{},
return make_tuple(k_offset, 0, m_offset, 0); Number<M2>{},
} Number<M3>{},
else Number<M4>{},
{ Number<N2>{}));
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
} }
__device__ static auto CalculateBThreadOriginDataIndex() template <typename CMNGridDesc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{ {
const index_t thread_id = get_thread_local_1d_id(); ///\To-do: pass CGrid desc transform deep inside xdlops gemm
const index_t waveId = thread_id / WaveSize; constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
const index_t laneId = thread_id % WaveSize; constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
const index_t waveId_n = waveId % NWaves; constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
if constexpr(xdlops_gemm.IsKReduction)
{ return transform_tensor_descriptor(
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); c_m_n_grid_desc,
const index_t k_offset = xdlops_gemm.GetBlkId(laneId); make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M2, M3, M4)),
return make_tuple(k_offset, 0, n_offset, 0); make_unmerge_transform(make_tuple(NRepeat, NWaves, N2))),
} make_tuple(Sequence<0>{}, Sequence<1>{}),
else make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i> __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
__device__ static CIndex
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{ {
return transform_tensor_descriptor(
const index_t waveId = get_thread_local_1d_id() / WaveSize; AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
const index_t waveId_m = waveId / NWaves; make_pass_through_transform(Number<K1>{})),
const index_t waveId_n = waveId % NWaves; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{m_offset, n_offset};
} }
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
{ {
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), return transform_tensor_descriptor(
"wrong! Desc should be known at compile-time"); BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), make_unmerge_transform(
"wrong! K dimension not consistent"); make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
"wrong! K1 dimension not consistent"); make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
} }
static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf, __device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
@@ -359,165 +230,88 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); vector_type<FloatAB, K1> a_thread_vec;
// read A_sub_0 vector_type<FloatAB, K1> b_thread_vec;
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0, I0), static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
a_block_buf, // read A
a_thread_desc_, a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(I0, I0, I0, I0), make_tuple(k, I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // read B
xdlops_gemm.template Run<decltype(a_thread_desc_), b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
decltype(b_thread_desc_), make_tuple(k, I0, I0, I0, I0),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I1, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
// read A_sub_1 using mfma_input_type =
a_thread_copy_.Run(ABlockDesc{}, typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
make_tuple(k, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 static_for<0, MRepeat, 1>{}([&](auto m0) {
xdlops_gemm.template Run<decltype(a_thread_desc_), static_for<0, NRepeat, 1>{}([&](auto n0) {
decltype(b_thread_desc_), static_for<0, K1, 1>{}([&](auto i) {
decltype(c_thread_desc_), a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf
0, [Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
0>(a_thread_buf, b_thread_buf, c_thread_buf); });
// C_sub_01 += transpose(A_sub_0) * B_sub_1 static_for<0, K1, 1>{}([&](auto i) {
xdlops_gemm.template Run<decltype(a_thread_desc_), b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf
decltype(b_thread_desc_), [Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
decltype(c_thread_desc_), });
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf); constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run<c_offset>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
}); });
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
} }
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{})); make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{})); make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ = static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
ABlockDesc, decltype(a_k0_m0_m1_m2_k1_block_desc),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, MRepeat, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3, 4>,
3, 4,
1, // K1, K1,
1>; 1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
BBlockDesc, decltype(b_k0_n0_n1_n2_k1_block_desc),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, NRepeat, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3, 4>,
3, 4,
1, // K1, K1,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
} // namespace ck
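After this refactor the blockwise GEMM addresses the C output through an 8-d (M0, N0, M1, N1, M2, M3, M4, N2) view: repeat, wave and intra-xdlops coordinates for M and N. A rough sketch of how such an 8-d coordinate collapses back to a flat (m, n) pair, with all per-dimension lengths assumed for illustration (not taken from this diff):

#include <cstdio>

int main()
{
    // Assumed lengths: M1=MWaves, M2/M3/M4 from the xdlops C layout,
    // N1=NWaves, N2=threads per blk.
    const int M1 = 2, M2 = 4, M3 = 2, M4 = 4;
    const int N1 = 2, N2 = 32;

    // One 8-d coordinate (m0, n0, m1, n1, m2, m3, m4, n2).
    const int m0 = 1, n0 = 0, m1 = 1, n1 = 1, m2 = 3, m3 = 0, m4 = 2, n2 = 7;

    // The grid descriptor unmerges m into (m0, m1, m2, m3, m4) and n into (n0, n1, n2),
    // so folding them back is plain mixed-radix arithmetic.
    const int m = (((m0 * M1 + m1) * M2 + m2) * M3 + m3) * M4 + m4;
    const int n = (n0 * N1 + n1) * N2 + n2;

    std::printf("8-d coord -> m = %d, n = %d\n", m, n);
    return 0;
}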
...
@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AK0MK1GridDesc,
           typename BK0NK1GridDesc,
-          typename CM0M1M2NGridDesc,
+          typename CM0N0M1N1M2M3M4N2GridDesc,
           typename CBlockClusterAdaptor>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
@@ -29,7 +29,7 @@ __global__ void
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc, const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
@@ -43,7 +43,7 @@ __global__ void
         p_shared_block,
         a_k0_m_k1_grid_desc,
         b_k0_n_k1_grid_desc,
-        c_m0_m1_m2_n_grid_desc,
+        c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
         c_block_cluster_adaptor);
 }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
@@ -52,7 +52,7 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AK0MK1GridDesc,
           typename BK0NK1GridDesc,
-          typename CM0M1M2NGridDesc,
+          typename CM0N0M1N1M2M3M4N2GridDesc,
           typename CBlockClusterAdaptor>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
@@ -63,7 +63,7 @@ __global__ void
         FloatC* __restrict__ p_c_grid,
         const void CONSTANT* p_a_k0_m_k1_grid_desc,
         const void CONSTANT* p_b_k0_n_k1_grid_desc,
-        const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
+        const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
         const void CONSTANT* p_c_block_cluster_adaptor)
 {
     constexpr index_t shared_block_size =
@@ -73,8 +73,9 @@ __global__ void
             cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
     const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
         cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
-    const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
-        cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
+    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
+        *reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
+            cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
     const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
         cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
@@ -86,7 +87,7 @@ __global__ void
         p_shared_block,
         a_k0_m_k1_grid_desc,
         b_k0_n_k1_grid_desc,
-        c_m0_m1_m2_n_grid_desc,
+        c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
         c_block_cluster_adaptor);
 }
 #endif
@@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};

     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
@@ -201,29 +205,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     }

     __host__ __device__ static constexpr auto
-    MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
-        constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
-
-        constexpr auto CLayout = xdlops_gemm.GetCLayout();
-
-        constexpr auto M0 = Number<CLayout.M1()>{};
-        constexpr auto M1 = Number<CLayout.N1()>{};
-        constexpr auto M2 = Number<CLayout.M0()>{};
-
-        constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
-        constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
-
-        constexpr auto N1 = Number<CLayout.N0()>{};
-
-        const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
-            c_m_n_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
-                       make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
-
-        return c_m0_m1_m2_n_grid_desc;
+        constexpr auto max_lds_align = K1;
+
+        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+
+        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+
+        using BlockwiseGemm =
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                FloatAB,
+                                                                decltype(a_k0_m_k1_block_desc),
+                                                                decltype(b_k0_n_k1_block_desc),
+                                                                MPerWave,
+                                                                NPerWave,
+                                                                MRepeat,
+                                                                NRepeat,
+                                                                K1>;
+
+        return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
     }

     __host__ __device__ static constexpr auto
@@ -253,8 +256,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }

-    using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
+    using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));

     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
@@ -262,7 +265,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                FloatAB* __restrict__ p_shared_block,
                                const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                               const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
+                               const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                const CBlockClusterAdaptor& c_block_cluster_adaptor)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -270,7 +273,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
+            p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

         const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
@@ -358,50 +361,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // register
         // sanity check

-        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
-                          NPerBlock % (NPerWave * NRepeat) == 0,
-                      "wrong!");
-
-        constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
-            a_k0_m_k1_block_desc,
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
-                       make_pass_through_transform(K1)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-
-        constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
-            b_k0_n_k1_block_desc,
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
-                       make_pass_through_transform(K1)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-
         const auto blockwise_gemm =
-            BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
-                                                 FloatAB,
-                                                 decltype(a_k0_m0_m1_k1_block_desc),
-                                                 decltype(b_k0_n0_n1_k1_block_desc),
-                                                 MPerWave,
-                                                 NPerWave,
-                                                 K1>{};
-
-        constexpr auto CLayout = blockwise_gemm.GetCLayout();
-
-        constexpr index_t BlkSize   = CLayout.GetBlkSize();
-        constexpr index_t NumBlks   = CLayout.GetNumBlks();
-        constexpr index_t NumXdlops = CLayout.GetNumXdlops();
-
-        static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                FloatAB,
+                                                                decltype(a_k0_m_k1_block_desc),
+                                                                decltype(b_k0_n_k1_block_desc),
+                                                                MPerWave,
+                                                                NPerWave,
+                                                                MRepeat,
+                                                                NRepeat,
+                                                                K1>{};

         constexpr auto c_mr_nr_blk_desc =
             make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

+        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
+            blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDesc();
+
+        constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
+
         StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     vector_type<FloatAcc, BlkSize>,
+                     vector_type<FloatAcc, CBlkSize>,
                      c_mr_nr_blk_desc.GetElementSpaceSize(),
                      true>
             c_thread_buf;
@@ -474,94 +453,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
#if 0
// output: register to global memory // output: register to global memory
{ {
constexpr index_t M0 = CLayout.M1(); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
constexpr index_t M1 = CLayout.N1(); blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
constexpr index_t M2 = CLayout.M0();
constexpr index_t N0 = CLayout.N1();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<M0>{},
Number<1>{},
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
c_thread_buf[Number<blk_off>{}]
.template AsType<FloatAcc>()[Number<j>{}];
});
});
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
n_thread_data_on_grid / (N1 * NWaves),
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
n_thread_data_on_grid % (N1 * NWaves) / N1,
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid % N1)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
}
#else
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{})); constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
@@ -574,92 +473,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks = CGridStepHacks{};
auto c_thread_copy = auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC, FloatC,
decltype(c_m0_m1_m2_n_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>, Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{ true>{
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
make_multi_index(0, make_multi_index(0,
0, 0,
0, 0,
0, 0,
m_thread_data_on_grid / (M2 * M1), m_thread_data_on_grid / (M3 * M4),
m_thread_data_on_grid % (M2 * M1) / M2, m_thread_data_on_grid % (M3 * M4) / M4,
m_thread_data_on_grid % M2, m_thread_data_on_grid % M4,
n_thread_data_on_grid)}; n_thread_data_on_grid)};
auto init_copy = [&](auto c_thread_idx_) { auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
return c_thread_idx_; return c_thread_idx_;
}; };
auto mrepeat_plus_copy = [&](auto c_thread_idx_) { auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
}; };
auto nrepeat_plus_copy = [&](auto c_thread_idx_) { auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
}; };
auto mrepeat_minus_copy = [&](auto c_thread_idx_) { auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
}; };
auto nrepeat_minus_copy = [&](auto c_thread_idx_) { auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
}; };
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
@@ -791,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                init_copy(make_tuple(I0, I0));
            }
        }
#endif
    }
}; // namespace ck
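In the rewritten output stage of this struct, the per-thread M coordinate returned by CalculateCThreadOriginDataIndex is split into the (m2, m3, m4) sub-coordinates of the 8-d grid view by successive division and modulo with M3 and M4. A standalone sketch of that split, with the lengths assumed for illustration (not taken from this diff):

#include <cstdio>

int main()
{
    // Assumed lengths of the m3 and m4 dimensions (they come from the xdlops C layout).
    const int M3 = 2, M4 = 4;

    const int m_thread_data_on_grid = 45; // example flat M index of this thread

    // Mirrors the make_multi_index(...) arguments used to construct c_thread_copy.
    const int m2 = m_thread_data_on_grid / (M3 * M4);
    const int m3 = m_thread_data_on_grid % (M3 * M4) / M4;
    const int m4 = m_thread_data_on_grid % M4;

    std::printf("m = %d -> m2 %d, m3 %d, m4 %d\n", m_thread_data_on_grid, m2, m3, m4);
    return 0;
}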
...
@@ -709,19 +709,59 @@ struct XdlopsGemm
        static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
    }
template <typename CM0N0M1N1M2N2GridDesc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto M0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I4);
constexpr auto N2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I5);
static_assert(N2 == mfma_type.num_threads_blk, "");
static_assert(
M2 == (mfma_type.num_groups_blk * mfma_type.num_output_blks * mfma_type.group_size),
"");
return transform_dynamic_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_grid_desc,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_type.num_groups_blk,
mfma_type.num_input_blks,
mfma_type.group_size)),
make_pass_through_transform(mfma_type.num_threads_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
}
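The MakeCM0N0M1N1M2M3M4N2GridDescriptor added above takes a 6-d (M0, N0, M1, N1, M2, N2) view, passes M0, N0, M1, N1 and N2 through, and unmerges M2 into (num_groups_blk, num_input_blks, group_size), giving the 8-d layout the blockwise GEMM writes. A rough bookkeeping sketch of the resulting dimension lengths, with the mfma geometry assumed for illustration (not taken from this diff):

#include <cstdio>

int main()
{
    // Assumed mfma parameters for illustration; the real values come from mfma_type.
    const int num_groups_blk = 4, num_input_blks = 2, group_size = 4;
    const int num_threads_blk = 32;

    // 6-d input lengths (M0, N0, M1, N1, M2, N2) of the c_m0_n0_m1_n1_m2_n2 view.
    const int len[6] = {2, 2, 2, 2, num_groups_blk * num_input_blks * group_size, num_threads_blk};

    // After the transform: 8 dimensions, with M2 replaced by its three factors.
    const int out[8] = {len[0], len[1], len[2], len[3],
                        num_groups_blk, num_input_blks, group_size, len[5]};

    for(int d = 0; d < 8; ++d)
        std::printf("dim %d length %d\n", d, out[d]);
    return 0;
}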
    __device__ static constexpr index_t GetRegSizePerXdlops()
    {
        return MPerXdlops * NPerXdlops / mfma_type.wave_size;
    }
-    template <class ADesc,
-              class BDesc,
-              class CDesc,
-              index_t m0,
-              index_t n0,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t c_offset, class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
@@ -730,24 +770,35 @@ struct XdlopsGemm
         static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");

-        constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
-
-        static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
-            constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
-            constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
-
+        static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
             mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
-                p_a_wave[Number<a_offset / mfma_type.k_base>{}],
-                p_b_wave[Number<b_offset / mfma_type.k_base>{}],
-                p_c_thread);
+                p_a_wave[k], p_b_wave[k], p_c_thread);
         });
     }
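The rewritten XdlopsGemm::Run above iterates k from 0 to KPack/k_base and indexes the already-vectorized operands directly, where the old loop stepped k by k_base and recovered the vector index by dividing a descriptor offset by k_base. Ignoring the m0/n0 part of that offset (which the refactor moves into the caller as the c_offset template argument), the two index sequences coincide; a sketch with assumed sizes:

#include <cassert>
#include <cstdio>

int main()
{
    const int KPack = 8, k_base = 4; // assumed example sizes, not from the diff

    // Old loop: k = 0, k_base, 2*k_base, ...; vector index was k / k_base.
    // New loop: k = 0, 1, ..., KPack/k_base - 1; vector index is k itself.
    for(int k_old = 0, k_new = 0; k_old < KPack; k_old += k_base, ++k_new)
    {
        assert(k_old / k_base == k_new);
        std::printf("old k %d -> vector index %d (new k %d)\n", k_old, k_old / k_base, k_new);
    }
    return 0;
}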
static constexpr auto GetBlkIdx()
{
const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto blk_idx = threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto blk_id = blk_idx[Number<1>{}];
const auto blk_td = blk_idx[Number<2>{}];
return make_tuple(blk_id, blk_td);
}
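GetBlkIdx() above replaces the removed GetBlkId/GetBlkTd helpers; modulo the adaptor machinery it is the same lane decomposition, assuming wave_size == num_input_blks * num_threads_blk. A small sketch of that equivalence with assumed sizes (not taken from this diff):

#include <cassert>
#include <cstdio>

int main()
{
    // Assumed example mfma geometry: a 64-lane wave split into 2 input blocks of 32 threads.
    const int wave_size = 64, num_input_blks = 2, num_threads_blk = 32;
    assert(wave_size == num_input_blks * num_threads_blk);

    for(int tid = 0; tid < 2 * wave_size; tid += 17)
    {
        // Old helpers: decompose the lane id within the wave.
        const int lane       = tid % wave_size;
        const int old_blk_id = lane / num_threads_blk;
        const int old_blk_td = lane % num_threads_blk;

        // New GetBlkIdx(): merge(1, num_input_blks, num_threads_blk) over the thread id.
        const int new_blk_td = tid % num_threads_blk;
        const int new_blk_id = (tid / num_threads_blk) % num_input_blks;

        assert(old_blk_id == new_blk_id && old_blk_td == new_blk_td);
        std::printf("tid %3d -> blk_id %d, blk_td %2d\n", tid, new_blk_id, new_blk_td);
    }
    return 0;
}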
     __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
     {
-        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
-        const index_t blk_id = laneId / mfma_type.num_threads_blk;
-        const index_t blk_td = laneId % mfma_type.num_threads_blk;
+        const auto blk_idx = GetBlkIdx();
+
+        const auto blk_id = blk_idx[Number<0>{}];
+        const auto blk_td = blk_idx[Number<1>{}];

         index_t n_offset = blk_i * mfma_type.n + blk_td;
         index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
@@ -755,24 +806,12 @@ struct XdlopsGemm
         return CIndex{m_offset, n_offset};
     }

-    static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
-    static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
     static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
     static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
-    static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
-    static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
     static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();

-    static constexpr auto GetBlkId(const index_t lane_id)
-    {
-        return lane_id / mfma_type.num_threads_blk;
-    }
-
-    static constexpr auto GetBlkTd(const index_t lane_id)
-    {
-        return lane_id % mfma_type.num_threads_blk;
-    }
+    static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
+    static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();

     static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
@@ -794,7 +833,7 @@ struct XdlopsGemm
        }
    };

-    __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
+    __host__ __device__ static constexpr auto GetCXdlopsLayout() { return CLayout{}; }
};

} // namespace ck
...
@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
             "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
     }

-    const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
+    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
+        GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);

-    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
+    using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

     const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
         FloatC,
         remove_reference_t<AK0MK1GridDesc>,
         remove_reference_t<BK0NK1GridDesc>,
-        remove_reference_t<CM0M1M2NGridDesc>,
+        remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
         remove_reference_t<CBlockClusterAdaptor>>;

 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
         p_c_grid,
         a_k0_m_k1_grid_desc,
         b_k0_n_k1_grid_desc,
-        c_m0_m1_m2_n_grid_desc,
+        c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
         c_block_cluster_adaptor);
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
     DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
     DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
-    DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
+    DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
     DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));

     a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
     b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
-    c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
+    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
     c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

     float ave_time = launch_and_time_kernel(
@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
         p_c_grid,
         cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
         cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
+        cast_pointer_to_constant_address_space(
+            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
         cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
 #endif

     return ave_time;
...
@@ -24,7 +24,7 @@
 #define USE_CONV_FWD_V4R4R2_NHWC 1
 #define USE_CONV_FWD_V6R1_NCHW 0
 #define USE_CONV_FWD_V5R1_NCHW 0
-#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
+#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0

 enum ConvForwardAlgo
...