Commit 267e7c7a authored by Alan Turner's avatar Alan Turner
Browse files

Allow for mixed types with int8 gemms

parent d028080a
...@@ -28,4 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2 ...@@ -28,4 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@01305bf28afd5d4f37702e6722f2441bea99b6f2 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@84c5bec1d66a633802fd977bd61e0aada7a6f153 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
...@@ -110,7 +110,7 @@ struct find_ck_gemm_pointwise ...@@ -110,7 +110,7 @@ struct find_ck_gemm_pointwise
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
if(ins->get_shape().type() != gemm_ins->get_shape().type()) if(gemm_ins->get_shape().type() != shape::int32_type and ins->get_shape().type() != gemm_ins->get_shape().type())
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type()); return not is_ck_supported_type(input->get_shape().type());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment