#include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { shape hip_loop::compute_shape(std::vector inputs, std::vector mods) const { auto input_num = (inputs.size() - 2) / 2; inputs.erase(inputs.begin() + input_num, inputs.end()); return op.compute_shape(inputs, std::move(mods)); } struct gpu_loop { int64_t max_iterations = 0; template void copy(context& ctx, const argument& src, T& dst) const { argument arg_dst{src.get_shape(), &dst}; copy_from_gpu(ctx, src, arg_dst); } template void copy(context& ctx, T src, const argument& dst) const { argument arg_src{dst.get_shape(), &src}; copy_to_gpu(ctx, arg_src, dst); } void append(const std::vector&, const std::vector&, int) const {} void set_zero(context& ctx, const std::vector& concatenated_outputs, int iter) const { if(iter >= max_iterations) return; auto elem_num = max_iterations - iter; for(const auto& out : concatenated_outputs) { auto s = out.get_shape(); auto size = s.bytes() / max_iterations; auto lens = s.lens(); lens[0] = elem_num; shape ss{s.type(), lens}; assert(ss.bytes() + iter * size <= out.get_shape().bytes()); device::fill(ctx.get_stream().get(), argument(ss, out.data() + iter * size), 0); } } std::unordered_map get_output_params(const module& m) const { auto get_output_index = [](const std::string& name) { std::string out_prefix = "#output_"; auto loc = name.find(out_prefix); if(loc != std::string::npos) { int index = std::stoi(name.substr(loc + out_prefix.size())); return index; } return -1; }; const auto& param_names = m.get_parameter_names(); std::unordered_map result; for(const auto& name : param_names) { auto index = get_output_index(name); if(index == -1) continue; result[name] = index; } return result; } }; argument hip_loop::compute(context& ctx, const shape&, const std::vector& args, const std::vector& mods, const std::function( module_ref&, const std::unordered_map&)>& run) const { return run_loop(gpu_loop{op.max_iterations}, ctx, args, mods, run); } } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx