Commit b41b7699 authored by Paul's avatar Paul
Browse files

Skip if any input is not a supported ck type

parent 8a2837c4
......@@ -31,8 +31,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct parse_where : op_parser<parse_where>
{
std::vector<op_desc> operators() const { return {{"Where"}}; }
......@@ -59,12 +57,6 @@ struct parse_where : op_parser<parse_where>
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if (enabled(MIGRAPHX_ENABLE_CK{}))
{
// Convert condition tensor to int32 to work around CK not supporting bool type
args[0] = info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
}
if(args[0]->get_shape().lens() != lens)
{
args[0] =
......
......@@ -72,12 +72,16 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
if(not is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
......@@ -107,6 +111,10 @@ struct find_ck_gemm_pointwise
auto gemm_idx = gemm_it - inputs.begin();
if(ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if (std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
}))
return;
assert(gemm_it != inputs.end());
if(gemm_idx != 0)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment