"vscode:/vscode.git/clone" did not exist on "f656ffe70e9f7755dc612f50c6cc3354e0b169e7"
Commit 224b7aa0 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents c4b1102e 4420ccbd
......@@ -26,6 +26,9 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,12 +58,12 @@ struct parse_split : op_parser<parse_split>
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
}
}
else if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Split: dynamic shape is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
// no split attribute, input is equally divided
else
......@@ -74,6 +77,15 @@ struct parse_split : op_parser<parse_split>
vec_splits.resize(info.num_outputs, dl);
}
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) +
" Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens));
}
std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
......
......@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
}
// if perm is empty, use the default value
auto n_dim = args.front()->get_shape().lens().size();
auto n_dim = args.front()->get_shape().ndim();
if(perm.empty())
{
perm.resize(n_dim);
......
......@@ -94,11 +94,19 @@ struct module_pm : module_pass_manager
virtual void run_pass(const pass& p) override
{
assert(mod);
timer ts{};
using seconds = std::chrono::duration<double>;
trace("Module: ", mod->name(), ", Pass: ", p.name());
const double t1 = ts.record<seconds>();
assert(mod->validate() == mod->end());
p.apply(*this);
trace(*mod);
validate_pass(*mod, p, *t);
const double t2 = ts.record<seconds>();
trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
}
};
......
......@@ -51,7 +51,18 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto r = s0;
if(s0 != s1 or not s0.packed())
{
r = shape{s0.type(), s0.lens()};
if(s0.packed() != s1.packed())
{
r = s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
r = {s0.type(), s0.lens()};
}
}
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
......
......@@ -43,9 +43,9 @@ struct dnnl_convolution
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
shape adjust_shape(const shape& x, int i, const shape& output) const
{
auto s = base_adjust_shape(x);
auto s = base_adjust_shape(x, output);
if(i == 1 and op.group > 1)
{
// TODO: Add support for transposed weights
......
......@@ -37,9 +37,9 @@ struct dnnl_deconvolution
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
shape adjust_shape(const shape& x, int i, const shape& output) const
{
auto s = base_adjust_shape(x);
auto s = base_adjust_shape(x, output);
if(i == 1)
{
// The input and output channels are flipped for dnnl
......
......@@ -167,7 +167,7 @@ struct dnnl_op : auto_register_op<Derived>
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0));
return result;
}
shape base_adjust_shape(const shape& s) const
shape base_adjust_shape(const shape& s, const shape& output) const
{
if(s.broadcasted())
{
......@@ -183,7 +183,8 @@ struct dnnl_op : auto_register_op<Derived>
else
return len;
});
return shape{s.type(), lens};
// Use the permutation of the output
return output.with_lens(s.type(), lens);
}
return s;
}
......@@ -204,7 +205,10 @@ struct dnnl_op : auto_register_op<Derived>
i++;
}
}
shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); }
shape adjust_shape(const shape& s, int, const shape& output) const
{
return base_adjust_shape(s, output);
}
std::vector<int> create_arg_map(std::size_t input_size) const
{
const auto& self = static_cast<const Derived&>(*this);
......@@ -224,12 +228,12 @@ struct dnnl_op : auto_register_op<Derived>
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size(), output_shape));
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i));
result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i, output_shape));
}
return result;
}
......
......@@ -32,7 +32,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
{
std::string name() const { return "dnnl::reorder"; }
shape adjust_shape(const shape& x, int) const { return x; }
shape adjust_shape(const shape& x, int, const shape&) const { return x; }
shape compute_shape(const std::vector<shape>& inputs) const
{
......
......@@ -33,6 +33,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
......@@ -82,6 +83,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
layout_nhwc{},
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{},
auto_contiguous{},
simplify_reshapes{},
......
......@@ -83,6 +83,7 @@ add_library(migraphx_gpu
compile_gen.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_miopen.cpp
compiler.cpp
device_name.cpp
fuse_mlir.cpp
......@@ -232,11 +233,14 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
if(HAS_FIND_2_API)
# TODO: Set default to HAS_FIND_2_API
set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
message(STATUS "MIOpen does not have Find-2.0 API")
message(STATUS "MIGraphx is using legacy Find API in MIOpen")
endif()
if(HAS_FIND_MODE_API)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/rocblas.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct miopen_op
{
operation op = op::identity{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::miopen_op"; }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.push_back(inputs.back());
return op.compute_shape(inputs);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get<std::size_t>("workspace", 0);
}
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0;
try
{
// for the regular convolution and deconvolution, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc);
m.replace_instruction(ins, op, inputs);
}
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -39,19 +39,26 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op
{
operation op = op::identity{};
operation op = op::identity{};
std::size_t additional_args = 1;
bool ignore_modules = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
return pack(f(self.op, "op"),
f(self.additional_args, "additional_args"),
f(self.ignore_modules, "ignore_modules"));
}
std::string name() const { return "gpu::precompile_op"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
inputs.pop_back();
// Pop off additional args
inputs.resize(inputs.size() - additional_args);
if(ignore_modules)
return op.compute_shape(inputs);
return op.compute_shape(inputs, mods);
}
......
......@@ -772,11 +772,9 @@ struct find_layernorm_pointwise
{
auto ins = r.result;
auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front();
if(not layernorm->module_inputs().empty())
return;
auto* pm = ins->module_inputs().front();
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
......@@ -785,6 +783,37 @@ struct find_layernorm_pointwise
}
};
struct find_concat_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(
match::arg(0)(precompile_name("concat").bind("concat")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto concat = r.instructions["concat"];
if(not concat->module_inputs().empty())
return;
// TODO: Handle type conversions
if(ins->get_shape().type() != concat->get_shape().type())
return;
auto* pm = ins->module_inputs().front();
auto inputs = concat->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
auto op = concat->get_operator();
op.from_value({{"additional_args", ins->inputs().size() - 1}, {"ignore_modules", true}});
m.replace_instruction(ins, op, inputs, {pm});
}
};
void fuse_ops::apply(module& m) const
{
match::find_matches(m, find_contiguous_pointwise{});
......@@ -793,6 +822,7 @@ void fuse_ops::apply(module& m) const
run_passes(m, {dead_code_elimination{}});
match::find_matches(m,
find_layernorm_pointwise{},
find_concat_pointwise{},
find_gemm_pointwise{},
find_contiguous_tranpose_gemm{},
find_commutative_broadcast{});
......
......@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct context;
struct operation;
namespace gpu {
struct compile_miopen
{
context* ctx = nullptr;
std::string name() const { return "gpu::compile_miopen"; }
void apply(module& m) const;
std::size_t compile(operation& op, instruction_ref ins, bool format) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
......@@ -83,9 +83,10 @@ struct miopen_convolution
inline shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, op}.has(4).standard();
check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, op}.max_ndims(5);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
return migraphx::compute_shape<Op>(op, conv_inputs);
}
......@@ -144,12 +145,9 @@ struct miopen_convolution
#endif
}
inline void set_conv_descriptor()
void set_conv_descriptor()
{
if(cd == nullptr)
{
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
......@@ -239,7 +237,6 @@ struct miopen_convolution
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
......
......@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp>
#include <rocblas.h>
#include <rocblas/rocblas.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
concat<${axis}>(y, xs...);
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat<${axis}>(${concat_args})(${post}, y, xs...);
});
}
......@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}});
auto src = interpolate_string(
concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
auto v = op.to_value();
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
......
......@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
std::vector<std::string> names() const { return {"pointwise", "contiguous", "layout"}; }
static std::size_t oversubscribe_if(bool b)
{
......@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
if(op.name() == "contiguous")
if(contains({"layout", "contiguous"}, op.name()))
{
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
{{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}));
}
else
{
......
......@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
return Start{} * output_shape.strides[Axis];
});
constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s);
MIGRAPHX_ASSERT(offset < out.get_shape().element_space());
MIGRAPHX_ASSERT((s.element_space() + offset) <= out.get_shape().element_space());
return make_tensor_view(out.data() + offset, s);
}
template <index_int Axis, class Input, class Start, class... Ts>
constexpr auto concat_slices(Input input, Start start, Ts... xs)
{
return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); };
}
template <index_int Axis, class Input>
......@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>;
}
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
template <index_int Axis, class... Inputs>
__device__ auto concat(Inputs... inputs)
{
auto idx = make_index();
fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input) {
concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
idx.global_stride(input.get_shape().elements(),
[&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
};
}
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment