Commit 72aabeb5 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'layout-nhwc' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into gemm_to_conv

parents 197ca330 3c100748
...@@ -55,6 +55,7 @@ add_library(migraphx ...@@ -55,6 +55,7 @@ add_library(migraphx
insert_pad.cpp insert_pad.cpp
instruction.cpp instruction.cpp
json.cpp json.cpp
layout_nhwc.cpp
load_save.cpp load_save.cpp
make_op.cpp make_op.cpp
module.cpp module.cpp
...@@ -145,6 +146,7 @@ register_migraphx_ops( ...@@ -145,6 +146,7 @@ register_migraphx_ops(
if_op if_op
im2col im2col
isnan isnan
layout
leaky_relu leaky_relu
less less
load load
......
...@@ -59,6 +59,8 @@ void auto_contiguous::apply(module& m) const ...@@ -59,6 +59,8 @@ void auto_contiguous::apply(module& m) const
auto last = std::prev(m.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "layout")
continue;
// for last instruction that is NOT a return // for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
continue; continue;
......
...@@ -56,6 +56,8 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -56,6 +56,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{ {
if(not ins->get_operator().attributes().get("pointwise", false)) if(not ins->get_operator().attributes().get("pointwise", false))
continue; continue;
if(ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op")); assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass(); pm->set_bypass();
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP #define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -232,6 +233,19 @@ struct check_shapes ...@@ -232,6 +233,19 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed with certain layouts
*/
const check_shapes&
packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{
if(not this->all_of([&](const shape& s) {
return s.packed() and contains(layouts, find_permutation(s));
}))
MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
return *this;
}
/*! /*!
* Check all shapes are packed or broadcasted. * Check all shapes are packed or broadcasted.
*/ */
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
/**
* Transform convolutions to nhwc
*/
struct layout_nhwc
{
std::string name() const { return "layout_nhwc"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
#define MIGRAPHX_GUARD_OP_LAYOUT_HPP
#include <migraphx/config.hpp>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct layout : unary<layout>
{
std::vector<int64_t> permutation;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.permutation, "permutation"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).only_dims(permutation.size());
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return shape::from_permutation(t, lens, permutation);
}
auto apply() const
{
return [](auto x) { return x; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_OP_LAYOUT_HPP
File mode changed from 100644 to 100755
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
std::vector<instruction_ref> result;
fix([&](auto self, auto ins) {
if(pred(ins))
{
result.push_back(ins);
return;
}
for(auto input : ins->inputs())
self(input);
})(std::prev(m.end()));
return result;
}
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
std::unordered_set<instruction_ref> result;
std::vector<instruction_ref> outputs = find_lasts(m, [](auto ins) {
return ins->name() == "convolution" and ins->get_shape().lens().size() == 4;
});
for(auto output : outputs)
{
auto permutation = find_permutation(output->get_shape());
auto layout = m.insert_instruction(
std::next(output), make_op("layout", {{"permutation", permutation}}), output);
result.insert(m.replace_instruction(output, layout));
}
return result;
}
void transform_convolutions(module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "convolution")
continue;
if(ins->get_shape().lens().size() != 4)
continue;
auto v = ins->get_operator().to_value();
if(v.at("group").to<int>() > 1)
continue;
auto args = ins->inputs();
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
}
}
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "layout")
continue;
if(ins->get_shape() != ins->inputs().front()->get_shape())
continue;
if(contains(output_layouts, ins))
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
void layout_nhwc::apply(module_pass_manager& mpm) const
{
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -51,7 +51,18 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary> ...@@ -51,7 +51,18 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto r = s0; auto r = s0;
if(s0 != s1 or not s0.packed()) if(s0 != s1 or not s0.packed())
{ {
r = shape{s0.type(), s0.lens()}; if(s0.packed() != s1.packed())
{
r = s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
r = {s0.type(), s0.lens()};
}
} }
// Call to get_primitive to make sure an algo is available // Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs)); this->get_primitive(this->to_memory_desc(r, inputs));
......
...@@ -43,9 +43,9 @@ struct dnnl_convolution ...@@ -43,9 +43,9 @@ struct dnnl_convolution
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)}; return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
} }
shape adjust_shape(const shape& x, int i) const shape adjust_shape(const shape& x, int i, const shape& output) const
{ {
auto s = base_adjust_shape(x); auto s = base_adjust_shape(x, output);
if(i == 1 and op.group > 1) if(i == 1 and op.group > 1)
{ {
// TODO: Add support for transposed weights // TODO: Add support for transposed weights
......
...@@ -37,9 +37,9 @@ struct dnnl_deconvolution ...@@ -37,9 +37,9 @@ struct dnnl_deconvolution
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)}; return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
} }
shape adjust_shape(const shape& x, int i) const shape adjust_shape(const shape& x, int i, const shape& output) const
{ {
auto s = base_adjust_shape(x); auto s = base_adjust_shape(x, output);
if(i == 1) if(i == 1)
{ {
// The input and output channels are flipped for dnnl // The input and output channels are flipped for dnnl
......
...@@ -167,7 +167,7 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -167,7 +167,7 @@ struct dnnl_op : auto_register_op<Derived>
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)); std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0));
return result; return result;
} }
shape base_adjust_shape(const shape& s) const shape base_adjust_shape(const shape& s, const shape& output) const
{ {
if(s.broadcasted()) if(s.broadcasted())
{ {
...@@ -183,7 +183,8 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -183,7 +183,8 @@ struct dnnl_op : auto_register_op<Derived>
else else
return len; return len;
}); });
return shape{s.type(), lens}; // Use the permutation of the output
return output.with_lens(s.type(), lens);
} }
return s; return s;
} }
...@@ -204,7 +205,10 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -204,7 +205,10 @@ struct dnnl_op : auto_register_op<Derived>
i++; i++;
} }
} }
shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); } shape adjust_shape(const shape& s, int, const shape& output) const
{
return base_adjust_shape(s, output);
}
std::vector<int> create_arg_map(std::size_t input_size) const std::vector<int> create_arg_map(std::size_t input_size) const
{ {
const auto& self = static_cast<const Derived&>(*this); const auto& self = static_cast<const Derived&>(*this);
...@@ -224,12 +228,12 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -224,12 +228,12 @@ struct dnnl_op : auto_register_op<Derived>
const auto& self = static_cast<const Derived&>(*this); const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result; std::unordered_map<int, dnnl::memory::desc> result;
result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] = result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size())); to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size(), output_shape));
auto m = create_arg_map(inputs.size()); auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size()); assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++) for(int i = 0; i < inputs.size(); i++)
{ {
result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i)); result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i, output_shape));
} }
return result; return result;
} }
......
...@@ -32,7 +32,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder> ...@@ -32,7 +32,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
{ {
std::string name() const { return "dnnl::reorder"; } std::string name() const { return "dnnl::reorder"; }
shape adjust_shape(const shape& x, int) const { return x; } shape adjust_shape(const shape& x, int, const shape&) const { return x; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
...@@ -82,6 +83,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -82,6 +83,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
simplify_reshapes{}, simplify_reshapes{},
layout_nhwc{},
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{}, simplify_algebra{},
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
......
...@@ -83,6 +83,7 @@ add_library(migraphx_gpu ...@@ -83,6 +83,7 @@ add_library(migraphx_gpu
compile_gen.cpp compile_gen.cpp
compile_hip.cpp compile_hip.cpp
compile_hip_code_object.cpp compile_hip_code_object.cpp
compile_miopen.cpp
compiler.cpp compiler.cpp
device_name.cpp device_name.cpp
fuse_mlir.cpp fuse_mlir.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/rocblas.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct miopen_op
{
operation op = op::identity{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::miopen_op"; }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.push_back(inputs.back());
return op.compute_shape(inputs);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get("workspace", 0);
}
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0;
try
{
// for the regular convolution and deconvolution, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc);
m.replace_instruction(ins, op, inputs);
}
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct context;
struct operation;
namespace gpu {
struct compile_miopen
{
context* ctx = nullptr;
std::string name() const { return "gpu::compile_miopen"; }
void apply(module& m) const;
std::size_t compile(operation& op, instruction_ref ins, bool format) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
...@@ -83,9 +83,10 @@ struct miopen_convolution ...@@ -83,9 +83,10 @@ struct miopen_convolution
inline shape compute_shape(const std::vector<shape>& inputs) const inline shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, op}.has(4).standard(); check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, op}.max_ndims(5); check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
return migraphx::compute_shape<Op>(op, conv_inputs); return migraphx::compute_shape<Op>(op, conv_inputs);
} }
...@@ -144,12 +145,9 @@ struct miopen_convolution ...@@ -144,12 +145,9 @@ struct miopen_convolution
#endif #endif
} }
inline void set_conv_descriptor() void set_conv_descriptor()
{ {
if(cd == nullptr) cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
{
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
} }
value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input) value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
...@@ -239,7 +237,6 @@ struct miopen_convolution ...@@ -239,7 +237,6 @@ struct miopen_convolution
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed"); MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo; algo = perf.fwd_algo;
size_t solution_count; size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(), status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
......
...@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params}) ...@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous", "layout"}; }
static std::size_t oversubscribe_if(bool b) static std::size_t oversubscribe_if(bool b)
{ {
...@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
if(op.name() == "contiguous") if(contains({"layout", "contiguous"}, op.name()))
{ {
return replace(compile_op( return replace(compile_op(
ctx, ctx,
to_shapes(ins->inputs()), to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}})); {{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}));
} }
else else
{ {
......
...@@ -29,19 +29,14 @@ ...@@ -29,19 +29,14 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp> #include <migraphx/gpu/compiler.hpp>
...@@ -109,9 +104,9 @@ struct miopen_apply ...@@ -109,9 +104,9 @@ struct miopen_apply
add_extend_op("scatter_none"); add_extend_op("scatter_none");
add_extend_op("topk"); add_extend_op("topk");
add_convolution_op<op::convolution>("convolution"); add_convolution_op("convolution");
add_convolution_op<op::deconvolution>("deconvolution"); add_convolution_op("deconvolution");
add_convolution_op<op::quant_convolution>("quant_convolution"); add_convolution_op("quant_convolution");
add_gemm_op<op::dot>("dot"); add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot"); add_gemm_op<op::quant_dot>("quant_dot");
add_if_op(); add_if_op();
...@@ -238,34 +233,19 @@ struct miopen_apply ...@@ -238,34 +233,19 @@ struct miopen_apply
}); });
} }
template <typename Op>
void add_convolution_op(const std::string& name) void add_convolution_op(const std::string& name)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
operation conv = operation conv = make_op(
miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), int8_x4_format}; "gpu::" + name,
migraphx::context ctx = get_context(); {{"op", ins->get_operator().to_value()}, {"int8_x4_format", int8_x4_format}});
size_t ws_bytes = 0; auto output = insert_allocation(ins, ins->get_shape());
auto compile_conv_with_format = [&](bool format) {
conv = miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), format};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
ws_bytes = ws.get("workspace", 0);
};
try
{ // for the regular convolution and deconvolution, this try would always succeed
compile_conv_with_format(int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
compile_conv_with_format(not int8_x4_format);
}
auto args = ins->inputs(); return mod->replace_instruction(ins,
auto output = insert_allocation(ins, ins->get_shape()); make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
auto workspace = insert_allocation(ins, shape{shape::int8_type, {ws_bytes}}); ins->inputs().at(0),
return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output); ins->inputs().at(1),
output);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment