Commit 7c77c682 authored by Paul
Browse files

Make the config larger

parent 0881cee8
......@@ -156,6 +156,15 @@ template <typename ALayout,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CK_DeviceGemmMultipleD
{
// Compile-time index constants (ck::Number<N> instances). They are passed to
// descriptor accessors to pick a dimension, e.g. GetLength(I0) / GetLength(I1)
// in CheckValidity.
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
// Indices 2..7 are left commented out below; only I0 and I1 are defined here.
// static constexpr auto I2 = ck::Number<2>{};
// static constexpr auto I3 = ck::Number<3>{};
// static constexpr auto I4 = ck::Number<4>{};
// static constexpr auto I5 = ck::Number<5>{};
// static constexpr auto I6 = ck::Number<6>{};
// static constexpr auto I7 = ck::Number<7>{};
// MatrixPadder member specialized on GemmSpec with ck::index_t for all three size
// types, constructed from the per-block tile sizes (MPerBlock, NPerBlock, KPerBlock).
// NOTE(review): presumably used to pad the M/N/K extents of the grid descriptors
// according to GemmSpec -- confirm against the MatrixPadder definition.
ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
matrix_padder{MPerBlock, NPerBlock, KPerBlock};
......@@ -212,6 +221,34 @@ struct CK_DeviceGemmMultipleD
e_grid_desc_m_n_);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
// Validates a GEMM problem description at compile time and then defers to the
// gridwise kernel's own validity check.
//
// The descriptors are expected to expose their lengths as constant expressions:
// every consistency check below is a static_assert, so an inconsistent problem is
// rejected at compile time rather than by returning false.
//
// Parameters:
//   a_grid_desc_m_k   - descriptor of A; dimension I0 is read as M, I1 as K.
//   b_grid_desc_n_k   - descriptor of B; dimension I0 is read as N.
//   ds_grid_desc_m_n  - descriptors of the D tensors; only forwarded to GridwiseGemm.
//   e_grid_desc_m_n   - descriptor of E; dimensions I0/I1 must equal M/N.
//   block_2_etile_map - block-to-E-tile map; must accept e_grid_desc_m_n.
//
// Returns the result of GridwiseGemm::CheckValidity for the same arguments.
template <typename AGridDesc_M_K,
          typename BGridDesc_N_K,
          typename DsGridDesc_M_N,
          typename EGridDesc_M_N,
          typename Block2ETileMap>
static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                    const BGridDesc_N_K& b_grid_desc_n_k,
                                    const DsGridDesc_M_N& ds_grid_desc_m_n,
                                    const EGridDesc_M_N& e_grid_desc_m_n,
                                    const Block2ETileMap& block_2_etile_map)
{
    // Problem extents taken from the input descriptors.
    constexpr auto m_extent = a_grid_desc_m_k.GetLength(I0);
    constexpr auto k_extent = a_grid_desc_m_k.GetLength(I1);
    constexpr auto n_extent = b_grid_desc_n_k.GetLength(I0);

    // The output descriptor has to agree with the M and N read from the inputs.
    static_assert(m_extent == e_grid_desc_m_n.GetLength(I0));
    static_assert(n_extent == e_grid_desc_m_n.GetLength(I1));

    // Each extent has to be an exact multiple of its per-block tile size.
    static_assert(m_extent % MPerBlock == 0);
    static_assert(n_extent % NPerBlock == 0);
    static_assert(k_extent % KPerBlock == 0);

    // The block-to-E-tile map has to accept the output descriptor.
    static_assert(block_2_etile_map.CheckValidity(e_grid_desc_m_n));

    // Remaining, kernel-specific constraints are checked by the gridwise GEMM.
    return GridwiseGemm::CheckValidity(
        a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map);
}
// Value-initialized instances of the three element-wise operation types.
// NOTE(review): presumably applied to the A operand, the B operand, and the
// combined C/D/E stage respectively -- confirm against where they are used.
AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{};
CDEElementwiseOperation cde_element_op{};
......
......@@ -33,9 +33,9 @@ struct gemm_add_relu : verify_program<gemm_add_relu>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("1", {migraphx::shape::half_type, {16, 8}});
auto b = mm->add_parameter("2", {migraphx::shape::half_type, {8, 32}});
auto c = mm->add_parameter("3", {migraphx::shape::half_type, {16, 32}});
auto a = mm->add_parameter("1", {migraphx::shape::half_type, {256, 512}});
auto b = mm->add_parameter("2", {migraphx::shape::half_type, {512, 1024}});
auto c = mm->add_parameter("3", {migraphx::shape::half_type, {256, 1024}});
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = mm->add_instruction(migraphx::make_op("add"), dot, c);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment