Commit 6d57f78d authored by Alan Turner's avatar Alan Turner
Browse files

Add name() to CK compute shape throws; enforce mul has 2 args

parent 6fe8b43c
......@@ -55,7 +55,7 @@ struct ck_gemm
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
......@@ -96,7 +96,7 @@ struct ck_gemm_softmax_gemm
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW("Expected 3 inputs but got " + to_string(inputs.size()));
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];
......@@ -225,7 +225,7 @@ struct find_ck_gemm_softmax_gemm
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("pointwise")(match::either_arg(0, 1)(
auto mul = match::name("pointwise")(match::nargs(2), match::either_arg(0, 1)(
match::is_constant().bind("scale"), gemm1))(is_pointwise_scale());
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment