Commit e48e7f38 authored by Adam Osewski
Browse files

Expose the b2c_M01 parameter.

This allows it to be passed through the command line.
parent 2d6fe2cd
...@@ -32,7 +32,8 @@ struct DeviceGemm : public BaseOperator ...@@ -32,7 +32,8 @@ struct DeviceGemm : public BaseOperator
ck::index_t StrideC, ck::index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op,
ck::index_t b2c_M01) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -268,7 +268,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -268,7 +268,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
index_t StrideE, index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op,
index_t b2c_M01)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
...@@ -280,7 +281,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -280,7 +281,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, b2c_M01)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op} cde_element_op_{cde_element_op}
...@@ -359,13 +360,13 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -359,13 +360,13 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
#if 0 #if 1
const index_t grid_size = const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
#else #else
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_); const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_);
#endif #endif
const auto K = arg.a_grid_desc_m_k_.GetLength(I1); const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
...@@ -449,7 +450,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -449,7 +450,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
index_t StrideE, index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op,
index_t b2c_M01)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -462,7 +464,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -462,7 +464,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
StrideE, StrideE,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op}; cde_element_op,
b2c_M01};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -480,7 +483,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -480,7 +483,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
index_t StrideE, index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override CDEElementwiseOperation cde_element_op,
index_t b2c_M01) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -493,7 +497,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -493,7 +497,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
StrideE, StrideE,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op,
b2c_M01);
} }
// polymorphic // polymorphic
......
...@@ -411,7 +411,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -411,7 +411,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
index_t /*b2c_M01 = 8*/)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -618,7 +619,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -618,7 +619,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
index_t b2c_M01 = 8)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -631,7 +633,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -631,7 +633,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op,
b2c_M01};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -648,7 +651,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -648,7 +651,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op,
index_t b2c_M01) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -661,7 +665,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -661,7 +665,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op,
b2c_M01);
} }
// polymorphic // polymorphic
......
...@@ -241,10 +241,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -241,10 +241,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
#if 0 #if 0
// return block_id to E matrix tile idx (m0, n0) mapping // return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, index_t b2c_M01 = 8)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n); e_grid_desc_m_n, b2c_M01);
} }
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment