"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "eb438429196bf8afb8ea5195449e270fd91b97e9"
Unverified Commit c10a6e82 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Add layout check to IsSupportedArgument (#627)



* Add layout check to IsSupportedArgument

* Format

---------
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
parent 14b3504d
...@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return true; return true;
} }
// check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static bool CheckDLayout()
{
static bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>;
});
return valid;
}
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
...@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return false; return false;
} }
// Check supported layouts
// A0 - Row
// B0 - Col
// D0s - Rows
// B1 - Row or Col
// D1s - Rows
// E1 - Row
if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
is_same_v<tensor_layout::gemm::ColumnMajor,
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
D1sLayout,
NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_, arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_, arg.b1_grid_desc_n_k_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment