Commit 4a5b2257 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Format

parent 3e002b60
...@@ -789,13 +789,13 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -789,13 +789,13 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
} }
// check if DsLayout is supported // check if DsLayout is supported
template<typename RefLayout, typename DsLayout, const index_t NumDTensor> template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static bool CheckDLayout() static bool CheckDLayout()
{ {
static bool valid = true; static bool valid = true;
// iterate over DLayout tuple // iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false // if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>; valid = valid && is_same_v<RefLayout, DLayout>;
}); });
...@@ -816,12 +816,14 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -816,12 +816,14 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// B1 - Row or Col // B1 - Row or Col
// D1s - Rows // D1s - Rows
// E1 - Row // E1 - Row
if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> && if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> && is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() && CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> || (is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
is_same_v<tensor_layout::gemm::ColumnMajor, B1Layout>) && is_same_v<tensor_layout::gemm::ColumnMajor,
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() && B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
D1sLayout,
NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>)) is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{ {
return false; return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment