#include #include #include #include #include namespace migraph { namespace gpu { shape hip_mul::compute_shape(const std::vector& inputs) const { // check_shapes{inputs, *this}.has(3).standard(); check_shapes{inputs, *this}.has(3); return inputs.at(0); } argument hip_mul::compute(context& ctx, const shape&, const std::vector& args) const { device::mul(ctx.get_stream().get(), args[2], args[0], args[1]); return args[2]; } } // namespace gpu } // namespace migraph