Commit 6d9425ec authored by Jianfeng yan's avatar Jianfeng yan
Browse files

added block2CTileMap, but results are not correct

parent f4f94f70
......@@ -7,6 +7,7 @@
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "statically_indexed_array.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -181,6 +182,51 @@ struct DeviceGroupedGemmXdl
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
template <int GroupCount>
struct Block2CTileMap
{
Block2CTileMap() = default;
template <typename GemmDesc>
// Block2CTileMap(const StaticallyIndexedArray<GemmDesc, GroupCount>& gemm_desc)
Block2CTileMap(const std::vector<GemmDesc>& gemm_desc, const index_t N0) : N0_{N0}
{
for(index_t grp = 0; grp < GroupCount-1; ++grp)
{
assert(gemm_desc[grp].BlockEnd == gemm_desc[grp+1].BlockStart);
}
for(index_t grp = 0; grp < GroupCount; ++grp)
{
block_ptr_[grp] = gemm_desc[grp].BlockStart;
}
block_ptr_[GroupCount] = gemm_desc[GroupCount-1].BlockEnd;
}
template <typename Index>
__host__ __device__ constexpr auto CalculateBottomIndex(Index blockIdx) const
{
index_t block_id = blockIdx[Number<0>{}];
index_t local_block_id;
for(index_t grp = 0; grp < MaxGroupCount; ++grp)
{
if(block_id >= block_ptr_[grp] && block_id < block_ptr_[grp+1])
{
local_block_id = block_id - block_ptr_[grp];
}
}
return make_tuple(local_block_id / N0_, local_block_id % N0_);
// return make_tuple(local_block_id % N0_, local_block_id / N0_);
}
private:
index_t block_ptr_[GroupCount + 1];
index_t N0_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
......@@ -232,13 +278,14 @@ struct DeviceGroupedGemmXdl
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
const ADataType* a_ptr;
const BDataType* b_ptr;
CDataType* c_ptr;
ck::index_t BlockStart, BlockEnd;
// typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
Block2CTileMap<MaxGroupCount> block_2_ctile_map_;
};
// Argument
......@@ -301,15 +348,13 @@ struct DeviceGroupedGemmXdl
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
const auto block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
block_2_ctile_map_,
// block_2_ctile_map_,
static_cast<const ADataType*>(p_a[i]),
static_cast<const BDataType*>(p_b[i]),
static_cast<CDataType*>(p_c[i]),
......@@ -317,6 +362,11 @@ struct DeviceGroupedGemmXdl
BlockEnd});
}
}
for(index_t i = 0; i < gemm_shapes.size(); i++)
{
gemm_desc_kernel_arg_[i].block_2_ctile_map_ = Block2CTileMap<MaxGroupCount>{gemm_desc_kernel_arg_, gemm_desc_kernel_arg_[i].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetLength(Number<1>{})};
}
}
// private:
......@@ -328,6 +378,7 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation c_element_op_;
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
// Block2CTileMap<MaxGroupCount> block_2_ctile_map_;
index_t grid_size_;
};
......
......@@ -73,6 +73,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -84,7 +85,7 @@ __global__ void
i < group_count)
{
auto group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
// const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_[group_id].a_ptr,
......@@ -97,8 +98,8 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[group_id].block_2_ctile_map_,
block_id_grp);
gemm_desc_[group_id].block_2_ctile_map_);
// block_2_ctile_map);
}
});
#else
......@@ -426,8 +427,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
ck::index_t block_id = get_block_1d_id())
const Block2CTileMap& block_2_ctile_map)
// ck::index_t block_id = get_block_1d_id())
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
......@@ -440,7 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment