"src/op/gemm.cc" did not exist on "64f17c2f369e612cc297d358f607307a615bbb59"
Commit 5b526236 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

checks in compute_shape()

parent 865e71c3
...@@ -48,17 +48,11 @@ struct multinomial ...@@ -48,17 +48,11 @@ struct multinomial
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(2).only_dims(2); check_shapes{inputs, *this, true}.has(2).only_dims(2);
size_t sample_size = inputs.back().max_lens().back();
if(not contains({shape::int32_type, shape::int64_type}, dtype)) if(not contains({shape::int32_type, shape::int64_type}, dtype))
MIGRAPHX_THROW( MIGRAPHX_THROW(
"Multinomial: Invalid output type. Valid types are int32_type and int64_type."); "Multinomial: Invalid output type. Valid types are int32_type and int64_type.");
<<<<<<< HEAD
return inputs.front().normalize_standard(); return inputs.front().normalize_standard();
=======
return {dtype, {inputs.front().max_lens().front(), sample_size}};
>>>>>>> da4fa01ff583c14d23b9e10c5fc178a4e3fe3bc2
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment