Commit e48e7f38 authored by Adam Osewski's avatar Adam Osewski
Browse files

Expose b2c_m01 parameter.

In order to pass it through cmd line.
parent 2d6fe2cd
......@@ -32,7 +32,8 @@ struct DeviceGemm : public BaseOperator
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
CElementwiseOperation c_element_op,
ck::index_t b2c_M01) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -268,7 +268,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
CDEElementwiseOperation cde_element_op,
index_t b2c_M01)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
......@@ -280,7 +281,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, b2c_M01)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
......@@ -359,7 +360,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
#if 0
#if 1
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
#else
......@@ -449,7 +450,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
CDEElementwiseOperation cde_element_op,
index_t b2c_M01)
{
return Argument{p_a,
p_b,
......@@ -462,7 +464,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
StrideE,
a_element_op,
b_element_op,
cde_element_op};
cde_element_op,
b2c_M01};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -480,7 +483,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
CDEElementwiseOperation cde_element_op,
index_t b2c_M01) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -493,7 +497,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
StrideE,
a_element_op,
b_element_op,
cde_element_op);
cde_element_op,
b2c_M01);
}
// polymorphic
......
......@@ -411,7 +411,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
index_t /*b2c_M01 = 8*/)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -618,7 +619,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
index_t b2c_M01 = 8)
{
return Argument{p_a,
p_b,
......@@ -631,7 +633,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
StrideC,
a_element_op,
b_element_op,
c_element_op};
c_element_op,
b2c_M01};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -648,7 +651,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
CElementwiseOperation c_element_op,
index_t b2c_M01) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -661,7 +665,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
StrideC,
a_element_op,
b_element_op,
c_element_op);
c_element_op,
b2c_M01);
}
// polymorphic
......
......@@ -241,10 +241,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
#if 0
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, index_t b2c_M01 = 8)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
e_grid_desc_m_n, b2c_M01);
}
#else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment