Commit 2bb9444f authored by Paul's avatar Paul
Browse files

Allow storing stateful lambdas

parent 7838a9a8
......@@ -580,87 +580,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return false;
}
// check vector load/store
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector laod of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of Ds
// only support RowMajor for now
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
if(!all_valid)
{
return false;
}
// check vector store of E
// only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......@@ -812,10 +733,27 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
Block2ETileMap block_2_etile_map;
// for checking vector load/store
index_t MRaw;
index_t NRaw;
index_t KRaw;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e)
constexpr Descriptor(ADesc a,
BDesc b,
DsDesc ds,
EDesc e,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_)
: a_grid_desc_ak0_m_ak1{GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(a))},
b_grid_desc_bk0_n_bk1{GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
......@@ -834,24 +772,44 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
MRaw{e.GetLength(I0)},
NRaw{e.GetLength(I1)},
KRaw{a.GetLength(I1)},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
is_valid{GridwiseGemm::CheckValidity(
(DeviceOp::matrix_padder.PadADescriptor_M_K(a)),
DeviceOp::matrix_padder.PadBDescriptor_N_K(b),
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds),
DeviceOp::matrix_padder.PadCDescriptor_M_N(e),
block_2_etile_map) and IsSupported(e.GetLength(I0), e.GetLength(I1), a.GetLength(I1))}
(DeviceOp::matrix_padder.PadADescriptor_M_K(a)),
DeviceOp::matrix_padder.PadBDescriptor_N_K(b),
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds),
DeviceOp::matrix_padder.PadCDescriptor_M_N(e),
block_2_etile_map) and
IsSupported(MRaw, NRaw, KRaw)}
{
}
constexpr bool IsValid() const
{
return is_valid;
}
};
template <class ADesc, class BDesc, class DsDesc, class EDesc>
static constexpr auto make_descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e)
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
DsDesc ds,
EDesc e,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, DsDesc, EDesc>(a, b, ds, e);
return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
a, b, ds, e, a_element_op, b_element_op, cde_element_op);
}
template <class Desc, class DsPointer>
......@@ -870,9 +828,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid,
p_e_grid,
p_shared_block,
AElementwiseOperation{},
BElementwiseOperation{},
CDEElementwiseOperation{},
desc.a_element_op,
desc.b_element_op,
desc.cde_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -886,9 +844,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid,
p_e_grid,
p_shared_block,
AElementwiseOperation{},
BElementwiseOperation{},
CDEElementwiseOperation{},
desc.a_element_op,
desc.b_element_op,
desc.cde_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment