Commit a9815bf4 authored by charlie


Merge branch 'dyn_ref_multibroadcast' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_transpose
parents b80e2db1 2fa68ded
@@ -39,19 +39,26 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
 struct precompile_op
 {
     operation op = op::identity{};
+    std::size_t additional_args = 1;
+    bool ignore_modules = false;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.op, "op"));
+        return pack(f(self.op, "op"),
+                    f(self.additional_args, "additional_args"),
+                    f(self.ignore_modules, "ignore_modules"));
     }

     std::string name() const { return "gpu::precompile_op"; }

     shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
     {
-        inputs.pop_back(); // Pop off additional args
+        inputs.resize(inputs.size() - additional_args);
+        if(ignore_modules)
+            return op.compute_shape(inputs);
         return op.compute_shape(inputs, mods);
     }
...
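For context, a minimal standalone sketch of the trimming logic above, with Shape standing in for migraphx::shape and the delegation to the wrapped op faked (every name here is illustrative, not MIGraphX API):

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hedged sketch: gpu::precompile_op's argument list ends with output
// allocations, so compute_shape drops `additional_args` trailing shapes
// before delegating to the wrapped operator.
using Shape = std::string;

Shape compute_shape_sketch(std::vector<Shape> inputs, std::size_t additional_args)
{
    assert(inputs.size() > additional_args);
    inputs.resize(inputs.size() - additional_args); // drop trailing allocations
    return inputs.back(); // stand-in for op.compute_shape(inputs)
}

int main()
{
    // Two logical inputs plus one output allocation: additional_args = 1.
    assert(compute_shape_sketch({"x0", "x1", "alloc"}, 1) == "x1");
}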
@@ -772,11 +772,9 @@ struct find_layernorm_pointwise
     {
         auto ins = r.result;
         auto layernorm = r.instructions["layernorm"];
-        auto* pm = ins->module_inputs().front();
         if(not layernorm->module_inputs().empty())
             return;
+        auto* pm = ins->module_inputs().front();
         auto inputs = layernorm->inputs();
         inputs.pop_back();
         inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
@@ -785,6 +783,37 @@ struct find_layernorm_pointwise
     }
 };
+struct find_concat_pointwise
+{
+    auto matcher() const
+    {
+        return precompile_name("pointwise")(
+            match::arg(0)(precompile_name("concat").bind("concat")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto concat = r.instructions["concat"];
+        if(not concat->module_inputs().empty())
+            return;
+        // TODO: Handle type conversions
+        if(ins->get_shape().type() != concat->get_shape().type())
+            return;
+        auto* pm = ins->module_inputs().front();
+        auto inputs = concat->inputs();
+        inputs.pop_back();
+        inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
+
+        auto op = concat->get_operator();
+        op.from_value({{"additional_args", ins->inputs().size() - 1}, {"ignore_modules", true}});
+        m.replace_instruction(ins, op, inputs, {pm});
+    }
+};
+
 void fuse_ops::apply(module& m) const
 {
     match::find_matches(m, find_contiguous_pointwise{});
@@ -793,6 +822,7 @@ void fuse_ops::apply(module& m) const
     run_passes(m, {dead_code_elimination{}});
     match::find_matches(m,
                         find_layernorm_pointwise{},
+                        find_concat_pointwise{},
                         find_gemm_pointwise{},
                         find_contiguous_tranpose_gemm{},
                         find_commutative_broadcast{});
...
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include <rocblas.h>
+#include <rocblas/rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
 #include <migraphx/permutation.hpp>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct context;
struct operation;
namespace gpu {
struct compile_miopen
{
context* ctx = nullptr;
std::string name() const { return "gpu::compile_miopen"; }
void apply(module& m) const;
std::size_t compile(operation& op, instruction_ref ins, bool format) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
@@ -83,9 +83,10 @@ struct miopen_convolution
     inline shape compute_shape(const std::vector<shape>& inputs) const
     {
-        check_shapes{inputs, op}.has(4).standard();
+        check_shapes{inputs, op}.has(4);
         std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
-        check_shapes{conv_inputs, op}.max_ndims(5);
+        check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
+            {{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
         return migraphx::compute_shape<Op>(op, conv_inputs);
     }
@@ -144,12 +145,9 @@ struct miopen_convolution
 #endif
     }

-    inline void set_conv_descriptor()
+    void set_conv_descriptor()
     {
-        if(cd == nullptr)
-        {
-            cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
-        }
+        cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
     }

     value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
@@ -239,7 +237,6 @@ struct miopen_convolution
         if(status != miopenStatusSuccess)
             MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
         algo = perf.fwd_algo;
-
         size_t solution_count;
         status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
...
@@ -25,7 +25,7 @@
 #define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP

 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/config.hpp>
-#include <rocblas.h>
+#include <rocblas/rocblas.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
 static const char* const concat_kernel = R"__migraphx__(
 #include <migraphx/kernels/concat.hpp>
 #include <migraphx/kernels/vectorize.hpp>
+#include <migraphx/kernels/ops.hpp>
 #include <args.hpp>

 namespace migraphx {

+${preamble}
+
 extern "C" {
 __global__ void ${kernel}(${params})
 {
-    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
-        concat<${axis}>(y, xs...);
+    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
+        concat<${axis}>(${concat_args})(${post}, y, xs...);
     });
 }
@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
-        // TODO: Use reduce_dims
+        auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
         hip_compile_options options;
         options.inputs = inputs;
         options.output = inputs.back();
         options.params = "-Wno-float-equal";
+        options.kernel_name = v.get("kernel", "concat_kernel");
         auto axis = find_fast_axis(options.inputs);
         auto vec = vectorize::elements(ctx, axis, options.inputs);
-        options.kernel_name = v.get("kernel", "concat_kernel");
         options.set_launch_params(
             v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
-        auto src = interpolate_string(concat_kernel,
-                                      {{"kernel", options.kernel_name},
-                                       {"params", enum_params(inputs.size(), "void * private_p")},
-                                       {"args", enum_params(inputs.size(), "private_p")},
-                                       {"transformers", make_transformer_args(vec)},
-                                       {"axis", v.at("axis").to<std::string>()}});
+        auto src = interpolate_string(
+            concat_kernel,
+            {{"kernel", options.kernel_name},
+             {"params", enum_params(inputs.size(), "void * private_p")},
+             {"args", enum_params(inputs.size(), "private_p")},
+             {"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
+             {"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
+             {"post", v.get("post", std::string{"op::id{}"})},
+             {"transformers", make_transformer_args(vec)},
+             {"preamble", v.get("preamble", std::string{})},
+             {"axis", v.at("axis").to<std::string>()}});
         return compile_hip_code_object(src, options);
     }

     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
-        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
+        auto v = op.to_value();
+        if(not ins->module_inputs().empty())
+        {
+            auto* pm = ins->module_inputs().front();
+            v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
+            v["preamble"] = generate_pointwise(*pm, "post_concat");
+            v["post"] = "MIGRAPHX_LIFT(post_concat)";
+            v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
+        }
+        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
     }
 };
...
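To make the ${concat_params}/${concat_args} substitution concrete, here is a hedged sketch of the string enum_params produces (the real helper lives in MIGraphX's string utilities; this analogue only mirrors its observable output):

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

// Joins "prefix0,prefix1,..." for `count` items, like enum_params.
std::string enum_params_sketch(std::size_t count, const std::string& prefix)
{
    std::ostringstream ss;
    for(std::size_t i = 0; i < count; ++i)
        ss << (i == 0 ? "" : ",") << prefix << i;
    return ss.str();
}

int main()
{
    // With two concat inputs the fused kernel's lambda head becomes:
    //   [](auto y, auto concat_x0,auto concat_x1, auto... xs)
    std::cout << enum_params_sketch(2, "auto concat_x") << "\n"; // ${concat_params}
    std::cout << enum_params_sketch(2, "concat_x") << "\n";      // ${concat_args}
}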
@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})

 struct pointwise_compiler : compiler<pointwise_compiler>
 {
-    std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
+    std::vector<std::string> names() const { return {"pointwise", "contiguous", "layout"}; }

     static std::size_t oversubscribe_if(bool b)
     {
@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
-        if(op.name() == "contiguous")
+        if(contains({"layout", "contiguous"}, op.name()))
         {
             return replace(compile_op(
                 ctx,
                 to_shapes(ins->inputs()),
-                {{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
+                {{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}));
         }
         else
         {
...
@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
         return Start{} * output_shape.strides[Axis];
     });
     constexpr auto s = make_shape(lens, strides);
-    return make_tensor_view(&out[offset], s);
+    MIGRAPHX_ASSERT(offset < out.get_shape().element_space());
+    MIGRAPHX_ASSERT((s.element_space() + offset) <= out.get_shape().element_space());
+    return make_tensor_view(out.data() + offset, s);
+}
+
+template <index_int Axis, class Input, class Start, class... Ts>
+constexpr auto concat_slices(Input input, Start start, Ts... xs)
+{
+    return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); };
 }

 template <index_int Axis, class Input>
@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
     return _c<lens[Axis]>;
 }

-template <index_int Axis, class Output, class... Inputs>
-__device__ void concat(Output output, Inputs... inputs)
+template <index_int Axis, class... Inputs>
+__device__ auto concat(Inputs... inputs)
 {
-    auto idx = make_index();
-    fold([&](auto start, auto input) {
-        auto y = concat_slice<Axis>(output, input, start);
-        idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
-        return start + concat_ends<Axis>(input);
-    })(_c<0>, inputs...);
+    return [=](auto f, auto... ts) {
+        auto idx = make_index();
+        fold([&](auto start, auto input) {
+            concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
+                idx.global_stride(input.get_shape().elements(),
+                                  [&](auto i) { y[i] = f(input[i], xs[i]...); });
+            });
+            return start + concat_ends<Axis>(input);
+        })(_c<0>, inputs...);
+    };
 }

 } // namespace migraphx
...
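The rewritten concat returns a lambda that applies a pointwise post-op while writing each input into its slice of the output. A hedged CPU analogue of that data flow for 1-D inputs, with no MIGraphX dependencies (names illustrative):

#include <cstddef>
#include <cstdio>
#include <vector>

// Hedged CPU analogue of the fused device kernel: each input occupies a
// slice of the output along the concat axis, and the pointwise post-op `f`
// is applied as the slice is written (y[i] = f(input[i])).
template <class F>
void concat_with_post(std::vector<float>& out,
                      const std::vector<std::vector<float>>& inputs,
                      F f)
{
    std::size_t start = 0; // running offset, like the fold over concat_ends
    for(const auto& in : inputs)
    {
        for(std::size_t i = 0; i < in.size(); ++i)
            out[start + i] = f(in[i]);
        start += in.size();
    }
}

int main()
{
    std::vector<float> y(5);
    concat_with_post(y, {{1, 2}, {3, 4, 5}}, [](float x) { return 2 * x; });
    for(float v : y)
        std::printf("%g ", v); // prints: 2 4 6 8 10
    std::printf("\n");
}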
@@ -29,19 +29,14 @@
 #include <migraphx/instruction_ref.hpp>
 #include <migraphx/stringutils.hpp>
-#include <migraphx/op/convolution.hpp>
-#include <migraphx/op/deconvolution.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/if_op.hpp>
 #include <migraphx/op/reshape.hpp>
-#include <migraphx/op/quant_convolution.hpp>
 #include <migraphx/op/quant_dot.hpp>
 #include <migraphx/gpu/context.hpp>
-#include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/gemm.hpp>
-#include <migraphx/gpu/int8_conv_pack.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/compiler.hpp>
@@ -109,9 +104,9 @@ struct miopen_apply
         add_extend_op("scatter_none");
         add_extend_op("topk");

-        add_convolution_op<op::convolution>("convolution");
-        add_convolution_op<op::deconvolution>("deconvolution");
-        add_convolution_op<op::quant_convolution>("quant_convolution");
+        add_convolution_op("convolution");
+        add_convolution_op("deconvolution");
+        add_convolution_op("quant_convolution");
         add_gemm_op<op::dot>("dot");
         add_gemm_op<op::quant_dot>("quant_dot");
         add_if_op();
@@ -238,34 +233,19 @@ struct miopen_apply
         });
     }

-    template <typename Op>
     void add_convolution_op(const std::string& name)
     {
         apply_map.emplace(name, [=](instruction_ref ins) {
-            operation conv =
-                miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), int8_x4_format};
-            migraphx::context ctx = get_context();
-            size_t ws_bytes       = 0;
-            auto compile_conv_with_format = [&](bool format) {
-                conv     = miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), format};
-                auto ws  = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
-                ws_bytes = ws.get("workspace", 0);
-            };
-            try
-            { // for the regular convolution and deconvolution, this try would always succeed
-                compile_conv_with_format(int8_x4_format);
-            }
-            catch(migraphx::exception&)
-            {
-                // In case no solver supports the default format, retry using the other format.
-                compile_conv_with_format(not int8_x4_format);
-            }
-            auto args      = ins->inputs();
-            auto output    = insert_allocation(ins, ins->get_shape());
-            auto workspace = insert_allocation(ins, shape{shape::int8_type, {ws_bytes}});
-            return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output);
+            operation conv = make_op(
+                "gpu::" + name,
+                {{"op", ins->get_operator().to_value()}, {"int8_x4_format", int8_x4_format}});
+
+            auto output = insert_allocation(ins, ins->get_shape());
+
+            return mod->replace_instruction(ins,
+                                            make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
+                                            ins->inputs().at(0),
+                                            ins->inputs().at(1),
+                                            output);
         });
     }
...
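The net effect is two-phase lowering: lowering now emits an uncompiled gpu::miopen_op wrapper, and the new compile_miopen pass (see the target.cpp pass list further down) finalizes it. A toy, hedged model of that split; none of these types are the real MIGraphX classes:

#include <iostream>
#include <string>

// Toy model: lowering wraps the op without compiling it; a later pass
// compiles the wrapped op in place.
struct toy_op
{
    std::string name;
    bool compiled = false;
};

// Stand-in for add_convolution_op: build "gpu::<name>" but defer compilation.
toy_op lower(const std::string& name) { return toy_op{"gpu::" + name, false}; }

// Stand-in for the compile_miopen pass.
void compile_miopen_pass(toy_op& op) { op.compiled = true; }

int main()
{
    toy_op conv = lower("convolution");
    std::cout << conv.name << " compiled=" << conv.compiled << "\n"; // compiled=0
    compile_miopen_pass(conv);
    std::cout << conv.name << " compiled=" << conv.compiled << "\n"; // compiled=1
}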
@@ -101,7 +101,10 @@ struct mlir_handle
     mlir_handle(T p) : handle(ptr{p}) {}

-    T get() const { return handle.get().get(); }
+    T get() const
+    {
+        return handle.get().get(); // NOLINT(readability-redundant-smartptr-get)
+    }

     T release() { return handle.release().get(); }
...
@@ -35,6 +35,7 @@
 #include <migraphx/fuse_pointwise.hpp>
 #include <migraphx/inline_module.hpp>
 #include <migraphx/insert_pad.hpp>
+#include <migraphx/layout_nhwc.hpp>
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/normalize_ops.hpp>
 #include <migraphx/preallocate_param.hpp>
@@ -50,6 +51,7 @@
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/gpu/allocation_model.hpp>
+#include <migraphx/gpu/compile_miopen.hpp>
 #include <migraphx/gpu/compile_ops.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/context.hpp>
@@ -70,6 +72,7 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)

 struct id_pass
 {
@@ -120,6 +123,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         simplify_algebra{},
         simplify_reshapes{},
+        enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
+        dead_code_elimination{},
+        simplify_reshapes{},
         simplify_algebra{},
         prefuse_ops{},
         dead_code_elimination{},
@@ -136,8 +142,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
+        compile_miopen{&gctx},
+        dead_code_elimination{},
         pack_int8_args{},
         dead_code_elimination{},
+        adjust_allocation{gpu_allocation_model{}},
+        dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
         replace_allocate{gpu_allocation_model{}, options.offload_copy},
...
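layout_nhwc is gated behind the MIGRAPHX_ENABLE_NHWC environment variable. A hedged sketch of that gating pattern (enable_pass/enabled are MIGraphX internals; this analogue only shows the env-var switch):

#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>

// Hedged analogue of enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}):
// a pass runs only when its controlling environment variable is set.
using pass = std::function<void()>;

pass enable_pass(bool enabled, pass p)
{
    if(enabled)
        return p;
    return [] {}; // disabled: identity pass
}

int main()
{
    bool nhwc = std::getenv("MIGRAPHX_ENABLE_NHWC") != nullptr;
    std::vector<pass> passes = {
        enable_pass(nhwc, [] { std::cout << "layout_nhwc\n"; }),
        [] { std::cout << "dead_code_elimination\n"; },
    };
    for(auto& p : passes)
        p();
}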
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}});
}
migraphx::operation layout(std::vector<int64_t> permutation = {0, 1, 2, 3})
{
return migraphx::make_op("layout", {{"permutation", permutation}});
}
migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruction_ref ins)
{
return m.add_instruction(layout({0, 2, 3, 1}), ins);
}
TEST_CASE(conv_relu)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
auto w = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
auto conv = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
m1.add_instruction(migraphx::make_op("relu"), conv);
}
run_pass(m1);
migraphx::module m2;
{
auto x = add_layout_nhwc(
m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
auto w = add_layout_nhwc(m2,
m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {16, 8, 3, 3}})));
auto conv = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto conv_layout = m2.add_instruction(layout(), conv);
m2.add_instruction(migraphx::make_op("relu"), conv_layout);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(conv_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
auto w = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
auto conv = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
y);
m1.add_instruction(migraphx::make_op("add"), conv, b);
}
run_pass(m1);
migraphx::module m2;
{
auto x = add_layout_nhwc(
m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
auto w = add_layout_nhwc(m2,
m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {16, 8, 3, 3}})));
auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
auto conv = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto conv_layout = m2.add_instruction(layout(), conv);
auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
y);
m2.add_instruction(migraphx::make_op("add"), conv_layout, b);
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -454,6 +454,23 @@ def binary_dyn_brcst_add_test():
    return ([node], [arg0, arg1], [arg_out])
@onnx_test
def binary_dyn_brcst_attr_error_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [4, 5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT,
[None, 3, 4, 5])
arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
[None, 3, 4, 5])
node = onnx.helper.make_node('Add',
inputs=['0', '1'],
outputs=['out'],
broadcast=1,
axis=1)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def binary_dyn_brcst_mul_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT,
@@ -5738,6 +5755,92 @@ def split_test_default():
    return ([node], [x], [y1, y2])
@onnx_test
def split_test_no_attribute():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
split = np.ones(4) * 75
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)
node = onnx.helper.make_node(
'Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3', 'y4'],
)
return ([const_node, node], [x], [y1, y2, y3, y4])
@onnx_test
def split_test_no_attribute_invalid_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
split = np.ones(4)
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)
node = onnx.helper.make_node(
'Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3', 'y4'],
)
return ([const_node, node], [x], [y1, y2, y3, y4])
@onnx_test
def split_test_invalid_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[1, 1, 1])
return ([node], [x], [y1, y2, y3])
@onnx_test
def split_test_no_attribute_invalid_input_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[])
return ([node], [x], [y1, y2, y3])
@onnx_test @onnx_test
def sqrt_test(): def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
@@ -559,6 +559,14 @@ TEST_CASE(binary_dyn_brcst_add_test)
    EXPECT(p == prog);
}
TEST_CASE(binary_dyn_brcst_attr_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("binary_dyn_brcst_attr_error_test.onnx", options); }));
}
TEST_CASE(binary_dyn_brcst_mul_test)
{
    migraphx::program p;
@@ -5599,6 +5607,31 @@ TEST_CASE(split_test)
    EXPECT(p == prog);
}
TEST_CASE(split_test_no_attribute)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type, {4}, {1}};
std::vector<int> ind = {75, 75, 75, 75};
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {300, 15}});
mm->add_literal(migraphx::literal(si, ind));
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {75}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {75}}, {"ends", {150}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {150}}, {"ends", {225}}}), input);
auto r4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {225}}, {"ends", {300}}}), input);
mm->add_return({r1, r2, r3, r4});
auto prog = migraphx::parse_onnx("split_test_no_attribute.onnx");
EXPECT(p == prog);
}
TEST_CASE(split_test_default)
{
    migraphx::program p;
@@ -5614,6 +5647,23 @@ TEST_CASE(split_test_default)
    EXPECT(p == prog);
}
TEST_CASE(split_test_no_attribute_invalid_split)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("split_test_no_attribute_invalid_split.onnx"); }));
}
TEST_CASE(split_test_invalid_split)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_split.onnx"); }));
}
TEST_CASE(split_test_no_attribute_invalid_input_split)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("split_test_no_attribute_invalid_input_split.onnx"); }));
}
TEST_CASE(sqrt_test)
{
    migraphx::program p;
...
(New binary ONNX test files added; serialized protobuf content not shown:
 split_test_invalid_split.onnx,
 split_test_no_attribute.onnx,
 split_test_no_attribute_invalid_input_split.onnx,
 split_test_no_attribute_invalid_split.onnx)