Unverified Commit 226da497 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Remove where op workaround for ck (#1854)


Co-authored-by: default avatarkahmed10 <15948690+kahmed10@users.noreply.github.com>
parent 0802c19e
...@@ -31,8 +31,6 @@ namespace migraphx { ...@@ -31,8 +31,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct parse_where : op_parser<parse_where> struct parse_where : op_parser<parse_where>
{ {
std::vector<op_desc> operators() const { return {{"Where"}}; } std::vector<op_desc> operators() const { return {{"Where"}}; }
...@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where> ...@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where>
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(enabled(MIGRAPHX_ENABLE_CK{}))
{
// Convert condition tensor to int32 to work around CK not supporting bool type
args[0] = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
}
if(args[0]->get_shape().lens() != lens) if(args[0]->get_shape().lens() != lens)
{ {
args[0] = args[0] =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment