"...composable_kernel.git" did not exist on "0b9fe84014931464bd821c56f7dfac6e3d675ba2"
Unverified Commit b1f8ae37 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix contraction IsSupported checks (#1257)

parent 43879b89
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim)) if(!(valid_a_vector_size && valid_a_access_dim))
{ {
return false; return false;
...@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim)) if(!(valid_b_vector_size && valid_b_access_dim))
{ {
return false; return false;
...@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_d_vector_size = const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension. // Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim)) if(!(valid_d_vector_size && valid_d_access_dim))
{ {
valid_ds_access = false; valid_ds_access = false;
...@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_e_vector_size = const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension. // Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_; const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim)) if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment