Commit 870a396b authored by Khalique Ahmed

manual merge

parents 228b665c d309e02f
@@ -33,38 +33,6 @@
namespace migraphx {
-template <class T>
-struct implicit_conversion_op
-{
-    T x;
-    template <index_int N, class U>
-    constexpr operator vec<U, N>() const
-    {
-        if constexpr(vec_size<T>() == 0)
-        {
-            return x;
-        }
-        else
-        {
-            static_assert(vec_size<T>() == N, "Vector size mismatch");
-            return __builtin_convertvector(x, vec<U, N>);
-        }
-    }
-    template <class U>
-    constexpr operator U() const
-    {
-        return x;
-    }
-};
-template <class T>
-constexpr implicit_conversion_op<T> implicit_conversion(T x)
-{
-    return {x};
-}
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
...
@@ -21,28 +21,29 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#ifndef MIGRAPHX_GUARD_RTGLIB_INT_DIVIDE_HPP
-#define MIGRAPHX_GUARD_RTGLIB_INT_DIVIDE_HPP
-#include <migraphx/config.hpp>
-#include <cmath>
+#ifndef MIGRAPHX_GUARD_KERNELS_RANGES_HPP
+#define MIGRAPHX_GUARD_KERNELS_RANGES_HPP
+#include <migraphx/kernels/iota_iterator.hpp>
namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-template <class R, class T, class U>
-R floor_divide(T x, U y)
-{
-    return R(std::floor(double(x) / double(y)));
-}
-template <class R, class T, class U>
-R ceil_divide(T x, U y)
-{
-    return R(std::ceil(double(x) / double(y)));
-}
+template <class Iterator>
+struct iterator_range
+{
+    Iterator start;
+    Iterator last;
+    constexpr Iterator begin() const { return start; }
+    constexpr Iterator end() const { return last; }
+};
+constexpr iterator_range<iota_iterator> range(diff_int start, diff_int last)
+{
+    return {{start, {}}, {last, {}}};
+}
+constexpr iterator_range<iota_iterator> range(diff_int last) { return range(0, last); }
-} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
-#endif
+#endif // MIGRAPHX_GUARD_KERNELS_RANGES_HPP
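A minimal sketch of what the new header enables (device-side caller assumed; `out`, `in`, and `n` are illustrative names): `range` packs two iota_iterators into a begin/end pair so a plain index interval can drive a range-based for loop.

    // Hypothetical kernel-side loop over [0, n)
    for(auto i : migraphx::range(n))
        out[i] = in[i];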
@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul)
-template <class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
+template <class Op, class T, class Index, class F>
+__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
+    MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32
    constexpr index_int lanes_per_thread = 16;
#else
    constexpr index_int lanes_per_thread = 64;
#endif
    using type = decltype(f(0));
-    __shared__ type buffer[idx.nlocal() / lanes_per_thread];
+    __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
    type x = init;
    idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
    dpp_reduce(x, op);
@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
    return y;
}
#else
-template <class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
+template <class Op, class T, class Index, class F>
+__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
+    MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
    using type = decltype(f(0));
-    __shared__ type buffer[idx.nlocal()];
+    __shared__ type buffer[idx.max_nlocal()];
    type x = init;
    idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
    buffer[idx.local] = x;
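The substantive change above: the __shared__ buffer is now sized by the compile-time max_nlocal() rather than the runtime nlocal(), since a shared-memory array needs a constant bound, and the new assert records the resulting contract that kernels launch with their maximum local size. A hedged sketch of a call site (`data` and `n` are assumed names):

    // Sum n values across the workgroup; every lane ends up with the result.
    auto total = block_reduce(idx, op::sum{}, 0.0f, n, [&](auto i) { return data[i]; });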
@@ -196,17 +197,14 @@ struct block
    struct reducer
    {
        index idx;
-        Slicer slicer;
+        Slicer slice;
        template <class Op, class T, class Read>
        __device__ auto reduce(Op op, T init, Read read) const
        {
-            return sliced(slicer, [=](auto x, auto... xs) {
-                return vec_reduce(block_reduce(idx,
-                                               op,
-                                               init,
-                                               x.get_shape().elements(),
-                                               [&](auto j) { return read(x[j], xs[j]...); }),
-                                  op);
+            return sliced(slice, [=](auto x, auto... xs) {
+                return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
+                    return vec_reduce(read(x[j], xs[j]...), op);
+                });
            });
        }
@@ -220,7 +218,7 @@ struct block
        template <class F>
        __device__ auto inner(F f) const
        {
-            return sliced(slicer, [=](auto x, auto... xs) {
+            return sliced(slice, [=](auto x, auto... xs) {
                idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
            });
        }
@@ -228,7 +226,7 @@ struct block
    template <class Input>
    constexpr auto elements() const
    {
-        using reduce_type = decltype(slicer(Input{}));
+        using reduce_type = decltype(slice(Input{}));
        using value_type = typename Input::type;
        constexpr auto relements = get_shape_c<reduce_type>{}.elements();
        if constexpr(vec_size<value_type>() > 1)
@@ -262,11 +260,11 @@ struct lane
    struct reducer
    {
        index idx;
-        Slicer slicer;
+        Slicer slice;
        template <class Op, class T, class Read>
        __device__ auto reduce(Op op, T init, Read read) const
        {
-            return sliced(slicer, [=](auto x, auto... xs) {
+            return sliced(slice, [=](auto x, auto... xs) {
                using type = typename decltype(x)::type;
                type r = init;
                for(index_int j = 0; j < x.get_shape().elements(); j++)
@@ -286,7 +284,7 @@ struct lane
        template <class F>
        __device__ auto inner(F f) const
        {
-            return sliced(slicer, [=](auto x, auto... xs) {
+            return sliced(slice, [=](auto x, auto... xs) {
                for(index_int j = 0; j < x.get_shape().elements(); j++)
                {
                    f(x[j], xs[j]...);
@@ -297,7 +295,7 @@ struct lane
    template <class Input>
    constexpr auto elements() const
    {
-        using reduce_type = decltype(slicer(Input{}));
+        using reduce_type = decltype(slice(Input{}));
        return get_shape_c<reduce_type>{}.elements();
    }
};
...
@@ -128,6 +128,7 @@ struct shape
        result[0] = tidx;
        return result;
    }
+    /// Convert multi-index into a single index
    constexpr index_int single(index_array idx) const
    {
...
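For reference, single() is the usual stride-weighted flattening: with lens {2, 3, 4} (row-major strides {12, 4, 1}), the multi-index {1, 2, 3} maps to 1*12 + 2*4 + 3 = 23.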
@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
    reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
-        auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
-        auto batch_sum =
-            r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
-        r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output,
-                                                                                        input);
+#ifdef MIGRAPHX_USE_FAST_SOFTMAX
+        const auto c = vec_at(r.slice(input)[0], 0);
+#else
+        const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
+#endif
+        auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
+            return migraphx::convert<float>(migraphx::exp(x - c));
+        })(input);
+        r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
    });
}
...
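The rewrite leans on softmax being shift-invariant: softmax(x)_i = exp(x_i - c) / sum_j exp(x_j - c) for any constant c. The conventional choice c = max(x) keeps every exponent non-positive and rules out overflow; the MIGRAPHX_USE_FAST_SOFTMAX path instead takes c from the first element of the reduction slice, saving an entire max-reduction at the cost of that guarantee. Accumulating the sum through migraphx::convert<float> also limits rounding error when the input type is fp16.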
@@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
    }
}
+template <class T>
+struct implicit_conversion_op
+{
+    T x;
+    template <index_int N, class U>
+    constexpr operator vec<U, N>() const
+    {
+        if constexpr(vec_size<T>() == 0)
+        {
+            return x;
+        }
+        else
+        {
+            static_assert(vec_size<T>() == N, "Vector size mismatch");
+            return __builtin_convertvector(x, vec<U, N>);
+        }
+    }
+    template <class U>
+    constexpr operator U() const
+    {
+        return x;
+    }
+};
+template <class T>
+constexpr implicit_conversion_op<T> implicit_conversion(T x)
+{
+    return {x};
+}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
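The relocated helper defers the conversion target to the assignment site: vector payloads go through __builtin_convertvector with a compile-time size check, scalars through an ordinary conversion. A minimal sketch with illustrative values:

    vec<float, 4> f = {1.0f, 2.0f, 3.0f, 4.0f};
    vec<int, 4> n   = implicit_conversion(f); // vector path: sizes must match
    double d        = implicit_conversion(2); // scalar path: plain conversion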
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_leaky_relu::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1);
}
argument miopen_leaky_relu::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1;
float beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenActivationForward(ctx.get_stream().get_miopen(),
ad.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit());
return args[1];
}
void miopen_leaky_relu::finalize(context&, const shape&, const std::vector<shape>&)
{
ad = make_leaky_relu(op.alpha);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -26,24 +26,18 @@
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
-#include <migraphx/instruction_ref.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/op/convolution.hpp>
-#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp>
-#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
-#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp>
-#include <migraphx/gpu/convolution.hpp>
-#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
-#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/miopen.hpp>
-#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
@@ -81,85 +75,25 @@ struct miopen_apply
        (void)i;
    }
-    const std::unordered_set<std::string>& get_rocblas_fp32_archs()
-    {
-        static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
-        return supported_archs;
-    }
    void init()
    {
        assert(mod != nullptr);
        assert(pass != nullptr);
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
-        auto& ctx = get_context();
-        const auto device_name = trim(split_string(get_device_name(), ':').front());
-        if(contains(get_rocblas_fp32_archs(), device_name))
-            compute_fp32 = true;
-        rocblas_gemm_flags flag;
-        rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
-        int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
-#endif
+        auto& ctx      = get_context();
+        int8_x4_format = get_int8_x4_format(ctx);
+        compute_fp32   = get_compute_fp32_flag();
        offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
-        add_generic_op("acos");
-        add_generic_op("acosh");
-        add_generic_op("add");
-        add_generic_op("asin");
-        add_generic_op("asinh");
-        add_generic_op("atan");
-        add_generic_op("atanh");
-        add_generic_op("ceil");
        add_generic_op("contiguous");
-        add_generic_op("cos");
-        add_generic_op("cosh");
-        add_generic_op("div");
-        add_generic_op("equal");
-        add_generic_op("erf");
-        add_generic_op("exp");
-        add_generic_op("floor");
-        add_generic_op("greater");
-        add_generic_op("less");
-        add_generic_op("log");
-        add_generic_op("logical_and");
-        add_generic_op("logical_or");
-        add_generic_op("logical_xor");
-        add_generic_op("max");
-        add_generic_op("min");
-        add_generic_op("mul");
-        add_generic_op("not");
-        add_generic_op("pow");
-        add_generic_op("prelu");
-        add_generic_op("recip");
-        add_generic_op("relu");
-        add_generic_op("round");
-        add_generic_op("rsqrt");
-        add_generic_op("sigmoid");
-        add_generic_op("sign");
-        add_generic_op("sin");
-        add_generic_op("sinh");
-        add_generic_op("sqdiff");
-        add_generic_op("sqrt");
-        add_generic_op("sub");
-        add_generic_op("tan");
-        add_generic_op("tanh");
-        add_generic_op("where");
-        add_extend_op("abs");
        add_extend_op("argmax");
        add_extend_op("argmin");
-        add_extend_op("clip");
-        add_extend_op("convert");
-        add_extend_op("elu");
-        add_extend_op("gather");
-        add_extend_op("leaky_relu");
        add_extend_op("logsoftmax");
        add_extend_op("lrn");
        add_extend_op("multinomial");
        add_extend_op("nonzero");
-        add_extend_op("pad");
        add_extend_op("pooling");
        add_extend_op("prefix_scan_sum");
        add_extend_op("reverse");
@@ -169,16 +103,15 @@ struct miopen_apply
        add_extend_op("scatter_none");
        add_extend_op("topk");
-        add_batch_norm_inference_op();
-        add_convolution_op();
-        add_deconvolution_op();
+        add_convolution_op("convolution");
+        add_convolution_op("deconvolution");
+        add_convolution_op("quant_convolution");
        add_gemm_op<op::dot>("dot");
        add_gemm_op<op::quant_dot>("quant_dot");
        add_if_op();
        add_loop_op();
        add_neg_op();
        add_nms_op();
-        add_quant_convolution_op();
    }
    void copy_params() const
@@ -227,7 +160,8 @@ struct miopen_apply
        init();
        for(auto it = mod->begin(); it != mod->end(); it++)
        {
            auto s     = it->get_shape();
+            auto attrs = it->get_operator().attributes();
            if(apply_map.count(it->name()) > 0)
            {
                check_shape(s, apply_map.at(it->name())(it));
@@ -236,11 +170,37 @@
            {
                check_shape(s, insert_precompile_op(it));
            }
+            else if(attrs.contains("target"))
+            {
+                check_shape(s, insert_custom_op(it, attrs));
+            }
        }
        copy_params();
    }
+    instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const
+    {
+        const auto& custom_op = ins->get_operator();
+        if(attrs.at("target") == "cpu")
+        {
+            auto s = ins->get_shape();
+            std::vector<instruction_ref> cpu_inputs;
+            auto inputs = ins->inputs();
+            auto output = inputs.back();
+            std::transform(
+                inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
+                    return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
+                });
+            cpu_inputs.front() =
+                mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
+            auto cpu_out = mod->insert_instruction(ins, custom_op, cpu_inputs);
+            auto gpu_out =
+                mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
+            return mod->replace_instruction(ins, gpu_out);
+        }
+        return ins;
+    }
    instruction_ref insert_precompile_op(instruction_ref ins) const
    {
        auto output = insert_allocation(ins, ins->get_shape());
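For an operator tagged with target=cpu, the routine above inserts roughly the following sequence (an illustrative sketch of the resulting instruction stream, not verbatim IR):

    x_i = hip::copy_from_gpu(input_i)      // one copy per input
    x_0 = hip::sync_stream(x_0, x_1, ...)  // front input carries the stream sync
    y   = <custom_op>(x_0, x_1, ...)       // evaluated on the host
    out = hip::copy_to_gpu(y, output)      // result returns to the preallocated GPU buffer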
@@ -259,38 +219,6 @@ struct miopen_apply
        return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}}));
    }
-    void add_convolution_op()
-    {
-        apply_map.emplace("convolution", [=](instruction_ref ins) {
-            auto&& op = any_cast<op::convolution>(ins->get_operator());
-            auto conv = miopen_convolution{op, make_conv(op)};
-            auto ws   = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
-            auto workspace = insert_allocation(ins, ws);
-            auto output    = insert_allocation(ins, ins->get_shape());
-            return mod->replace_instruction(
-                ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
-        });
-    }
-    void add_deconvolution_op()
-    {
-        apply_map.emplace("deconvolution", [=](instruction_ref ins) {
-            auto&& op = any_cast<op::deconvolution>(ins->get_operator());
-            auto conv = miopen_deconvolution{op, make_deconv(op)};
-            auto ws   = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
-            auto workspace = insert_allocation(ins, ws);
-            auto output    = insert_allocation(ins, ins->get_shape());
-            return mod->replace_instruction(
-                ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
-        });
-    }
    template <typename Op>
    void add_gemm_op(const std::string& name)
    {
@@ -304,32 +232,19 @@ struct miopen_apply
        });
    }
-    void add_quant_convolution_op()
-    {
-        apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
-            auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
-            shape ws;
-            miopen_quant_convolution conv;
-            auto compile_quant_conv_with_format = [&](bool format) {
-                conv = miopen_quant_convolution{op, format, make_conv(op)};
-                ws   = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
-            };
-            try
-            {
-                compile_quant_conv_with_format(int8_x4_format);
-            }
-            catch(migraphx::exception&)
-            {
-                // In case no solver supports the default format, retry using the other format.
-                compile_quant_conv_with_format(not int8_x4_format);
-            }
-            auto args      = ins->inputs();
-            auto workspace = insert_allocation(ins, ws);
-            auto output    = insert_allocation(ins, ins->get_shape());
-            return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output);
+    void add_convolution_op(const std::string& name)
+    {
+        apply_map.emplace(name, [=](instruction_ref ins) {
+            operation conv = make_op(
+                "gpu::" + name,
+                {{"op", ins->get_operator().to_value()}, {"int8_x4_format", int8_x4_format}});
+            auto output = insert_allocation(ins, ins->get_shape());
+            return mod->replace_instruction(ins,
+                                            make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
+                                            ins->inputs().at(0),
+                                            ins->inputs().at(1),
+                                            output);
        });
    }
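All three MIOpen convolutions now lower through this one registration: the original operator is serialized into a gpu::&lt;name&gt; op and wrapped in gpu::miopen_op, leaving compilation and workspace allocation to the later compile_miopen pass. A hedged sketch of the equivalent by-hand construction for a plain convolution (illustrative only; it mirrors what the emplaced lambda builds):

    operation conv    = make_op("gpu::convolution",
                                {{"op", ins->get_operator().to_value()},
                                 {"int8_x4_format", true}});
    operation wrapped = make_op("gpu::miopen_op", {{"op", to_value(conv)}});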
@@ -363,43 +278,6 @@ struct miopen_apply
        });
    }
-    void add_batch_norm_inference_op()
-    {
-        apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) {
-            auto&& op       = any_cast<op::batch_norm_inference>(ins->get_operator());
-            auto output     = insert_allocation(ins, ins->get_shape());
-            shape old_shape = ins->inputs().at(1)->get_shape();
-            auto input      = ins->inputs()[0];
-            auto input_lens = input->get_shape().lens();
-            std::vector<int64_t> rsp_lens(input_lens.size(), 1);
-            // for per_activation case, also need to reshape input
-            if(op.bn_mode == op::batch_norm_inference::per_activation)
-            {
-                std::copy(input_lens.begin() + 1, input_lens.end(), rsp_lens.begin() + 1);
-            }
-            else
-            {
-                rsp_lens[1] = static_cast<int64_t>(old_shape.elements());
-            }
-            auto reshape_op = op::reshape{rsp_lens};
-            std::vector<instruction_ref> reshapes;
-            std::transform(ins->inputs().begin() + 1,
-                           ins->inputs().end(),
-                           std::back_inserter(reshapes),
-                           [&](auto i) { return mod->insert_instruction(ins, reshape_op, i); });
-            return mod->replace_instruction(ins,
-                                            miopen_batch_norm_inference{op},
-                                            input,
-                                            reshapes[0],
-                                            reshapes[1],
-                                            reshapes[2],
-                                            reshapes[3],
-                                            output);
-        });
-    }
    // use 0 - input to represent neg
    void add_neg_op()
    {
...
@@ -21,6 +21,7 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
+#include "migraphx/make_op.hpp"
#include <migraphx/gpu/mlir.hpp>
#ifdef MIGRAPHX_MLIR
@@ -31,7 +32,13 @@
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
-#include <mlir-c/Registration.h>
+#include <mutex>
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#warning "Incompatible version of rocMLIR library used, disabling"
+#undef MIGRAPHX_MLIR
+#else
+#include <mlir-c/RegisterRocMLIR.h>
+#endif
#endif
#include <migraphx/env.hpp>
@@ -43,15 +50,12 @@
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
-#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/perfdb.hpp>
+#include <migraphx/iterator_for.hpp>
+#include <migraphx/permutation.hpp>
#include <deque>
#include <variant>
-#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
-#define MIGRAPHX_MLIR_BARE_POINTER
-#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
@@ -99,7 +103,10 @@ struct mlir_handle
    mlir_handle(T p) : handle(ptr{p}) {}
-    T get() const { return handle.get().get(); }
+    T get() const
+    {
+        return handle.get().get(); // NOLINT(readability-redundant-smartptr-get)
+    }
    T release() { return handle.release().get(); }
@@ -163,9 +170,11 @@ struct mlir_program
          location(mlirLocationUnknownGet(ctx.get())),
          mmodule(mlirModuleCreateEmpty(location))
    {
-        MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
-        mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
-        mlirRegisterAllDialects(ctx.get());
+        MlirDialectRegistry registry = mlirDialectRegistryCreate();
+        mlirRegisterRocMLIRDialects(registry);
+        mlirContextAppendDialectRegistry(ctx.get(), registry);
+        mlirContextLoadAllAvailableDialects(ctx.get());
+        mlirDialectRegistryDestroy(registry);
        mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
    }
@@ -370,7 +379,11 @@ struct mlir_program
    mlir_operation_state& add_results(const std::vector<shape>& outputs)
    {
-        auto x = prog->make_tensors(outputs);
+        std::vector<shape> reshaped(outputs.size());
+        std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
+            return shape{r.type(), r.lens()};
+        });
+        auto x = prog->make_tensors(reshaped);
        mlirOperationStateAddResults(&op_state, x.size(), x.data());
        return *this;
    }
@@ -443,7 +456,8 @@ struct mlir_program
        auto ops = create_operation_state("func.func");
        ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
                            {"sym_name", std::string("main")},
-                            {"kernel", std::string("mixr")}});
+                            {"kernel", std::string("mixr")},
+                            {"arch", target_arch}});
        ops.add_region(std::move(region));
        insert(body, std::move(ops));
@@ -502,11 +516,13 @@ struct mlir_program
    {
        pp =
            problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
-        std::string tuned = get_tune_params();
+        // check if HW supports xdlops
+        auto target_chip  = trim(split_string(target_arch, ':').front());
+        bool xdlops       = contains(get_xdlops_archs(), target_chip);
+        std::string tuned = get_tune_params(xdlops);
        if(not tuned.empty())
            ops.add_attributes({{"perf_config", tuned}});
-        // check if HW supports xdlops
-        if(contains(get_xdlops_archs(), target_name))
+        if(xdlops)
            ops.add_attributes({{"xdlopsV2", true}});
    }
@@ -530,7 +546,7 @@ struct mlir_program
        // 1st pipeline to call
        mlirMIGraphXAddHighLevelPipeline(pm.get());
        // 2nd pipeline to call
-        mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
+        mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
        mlirPassManagerRun(pm.get(), mmodule.get());
        code_object_op op{};
@@ -540,16 +556,7 @@ struct mlir_program
        return op;
    }
-    void find_target()
-    {
-        std::string tname = get_device_name();
-        // HACK: Since MLIR can't handle the full target name
-        target_name = trim(split_string(tname, ':').front());
-        if(tname.size() != target_name.size())
-            std::cout
-                << "*************** WARNING: MLIR may not compile the correct target features for: "
-                << tname << std::endl;
-    }
+    void find_target() { target_arch = get_device_name(); }
    std::pair<std::size_t, std::size_t> get_launch_params() const
    {
@@ -571,14 +578,14 @@ struct mlir_program
        MIGRAPHX_THROW("Failed to compile mlir program");
    }
-    std::string get_tune_params() { return get_mlir_perf_for_conv(pp); }
+    std::string get_tune_params(bool xdlops) { return get_mlir_perf_for_conv(pp, xdlops); }
    mlir_context ctx;
    MlirLocation location;
    mlir_module mmodule;
    problem_params pp;
    std::deque<std::string> strings{};
-    std::string target_name;
+    std::string target_arch;
};
std::string dump_mlir(const module& m)
@@ -589,11 +596,61 @@ std::string dump_mlir(const module& m)
    return mlir_print(&mlirOperationPrint, mod_op);
}
-code_object_op compile_mlir(const context&, const module& m)
+void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
{
+    auto names = m.get_parameter_names();
+    std::sort(names.begin(), names.end());
+    for(auto i : range(names.size()))
+    {
+        const auto& name  = names[i];
+        const auto& input = inputs[i]->get_shape();
+        auto param        = m.get_parameter(name);
+        if(input.standard())
+            continue;
+        auto lens    = input.lens();
+        auto strides = input.strides();
+        std::vector<operation> ops;
+        if(input.transposed())
+        {
+            auto perm  = find_permutation(input);
+            auto iperm = invert_permutation(perm);
+            lens       = reorder_dims(lens, iperm);
+            strides    = reorder_dims(strides, iperm);
+            ops.push_back(make_op("transpose", {{"permutation", perm}}));
+        }
+        if(input.broadcasted())
+        {
+            std::transform(lens.begin(),
+                           lens.end(),
+                           strides.begin(),
+                           lens.begin(),
+                           [](auto len, auto stride) -> std::size_t {
+                               if(stride == 0)
+                                   return 1;
+                               return len;
+                           });
+            ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
+        }
+        auto new_param =
+            std::accumulate(ops.begin(),
+                            ops.end(),
+                            m.add_parameter(name + ".0", shape{input.type(), lens}),
+                            [&](auto x, auto op) { return m.insert_instruction(param, op, x); });
+        m.replace_instruction(param, new_param);
+        m.remove_instruction(param);
+    }
+}
+code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
+{
+    adjust_param_shapes(m, inputs);
    const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
    if(trace)
        std::cout << m << std::endl;
+    // set mutex while llvm thread support is disabled.
+    static std::mutex g_mlirc_mutex; // NOLINT
+    const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
    mlir_program mp;
    mp.find_target();
    mp.parse(m);
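As a worked illustration of adjust_param_shapes (shapes here are hypothetical): a parameter whose argument arrives transposed, say lens {4, 2} with strides {1, 4}, is re-registered as a standard {2, 4} parameter named `<name>.0` and followed by a transpose back to the caller's layout; broadcast dimensions (stride 0) are likewise shrunk to 1 and restored with multibroadcast. The effect is that the MLIR function only ever sees standard-layout arguments.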
@@ -613,46 +670,9 @@ instruction_ref insert_mlir(module& m,
    std::vector<instruction_ref> refs;
    std::size_t last = 0;
-#ifdef MIGRAPHX_MLIR_BARE_POINTER
    refs.reserve(inputs.size());
    std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
    last = refs.size() - 1;
-#else
-    refs.reserve(inputs.size() * 15);
-    std::unordered_map<uint64_t, instruction_ref> literal_map{};
-    auto get_literal = [&](uint64_t value) {
-        auto fi = literal_map.find(value);
-        if(fi != literal_map.end())
-            return fi->second;
-        auto lit = m.add_literal(value);
-        literal_map.emplace(value, lit);
-        return lit;
-    };
-    for(auto input : inputs)
-    {
-        const size_t offset = 0;
-        auto s = input->get_shape();
-        last   = refs.size();
-        refs.push_back(input);
-        refs.push_back(input);
-        refs.push_back(get_literal(offset)); // offset
-        // dim sizes
-        std::transform(s.lens().begin(),
-                       s.lens().end(),
-                       std::back_inserter(refs),
-                       [&](const auto& lval) { return get_literal(lval); });
-        // refs.push_back(get_literal(1)); // G
-        // dim strides
-        std::transform(s.strides().begin(),
-                       s.strides().end(),
-                       std::back_inserter(refs),
-                       [&](const auto& lval) { return get_literal(lval); });
-        // refs.push_back(get_literal(1)); // G
-    }
-#endif
    co.expected_inputs = to_shapes(refs);
    co.output_arg      = last;
    return m.insert_instruction(ins, co, refs);
@@ -662,13 +682,19 @@ instruction_ref insert_mlir(module& m,
std::string dump_mlir(const module&) { return {}; }
-code_object_op compile_mlir(const context&, const module&) { return {}; }
template <class T>
void use(T&)
{
}
+// Disabling clang-tidy warning on non-real usage.
+// NOLINTBEGIN(performance-unnecessary-value-param)
+code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&)
+{
+    return {};
+}
+// NOLINTEND(performance-unnecessary-value-param)
instruction_ref
// cppcheck-suppress funcArgNamesDifferent
insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
...
@@ -27,6 +27,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/permutation.hpp>
#include <fstream>
+#include <mutex>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -88,6 +89,9 @@ std::string generate_miopen_config(const problem_params& pp)
auto query_miopen_db(const std::string& query)
{
+    static std::mutex g_db_mutex; // NOLINT
+    const std::lock_guard<std::mutex> lock(g_db_mutex);
    // TODO: Store db as a static variable
    const auto dbpath = fs::path{"/opt"} / "rocm" / "share" / "miopen" / "db" / "miopen.db";
    // Check if db file exists.
@@ -108,16 +112,17 @@ auto query_miopen_db(const std::string& query)
} // namespace
-std::string get_mlir_perf_for_conv(const problem_params& pp)
+std::string get_mlir_perf_for_conv(const problem_params& pp, bool xdlops)
{
-    std::string query = "select P.* \
+    std::string solver = xdlops ? "ConvMlirIgemmFwdXdlops" : "ConvMlirIgemmFwd";
+    std::string query  = "select P.* \
                         from perf_db P, config C \
                         where P.config = C.id AND \
-                        P.solver = 'ConvMlirIgemmFwdXdlops' AND \
+                        P.solver = '${solver}' AND \
                        ${config}";
-    auto results =
-        query_miopen_db(interpolate_string(query, {{"config", generate_miopen_config(pp)}}));
+    auto results = query_miopen_db(
+        interpolate_string(query, {{"config", generate_miopen_config(pp)}, {"solver", solver}}));
    if(results.empty())
        return "";
    return results.front().at("params");
...
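The ${key} placeholders are resolved by interpolate_string before the query runs; a small sketch of the substitution (the config fragment here is made up):

    auto q = interpolate_string("P.solver = '${solver}' AND ${config}",
                                {{"solver", "ConvMlirIgemmFwd"},
                                 {"config", "C.in_channels = 64"}});
    // q == "P.solver = 'ConvMlirIgemmFwd' AND C.in_channels = 64"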
@@ -35,6 +35,12 @@ namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
+    float epsilon = 1e-12f;
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.epsilon, "epsilon"));
+    }
    shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
    {
        std::size_t nargs = 1;
@@ -45,23 +51,27 @@ struct layernorm_base
        }
        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
        auto s = inputs.at(0);
+        auto t = s.type();
+        if(not mods.empty())
+            t = mods.front()->get_output_shapes().front().type();
        if(s.scalar())
        {
            return s;
        }
        else if(s.broadcasted())
        {
-            return {s.type(), s.lens()};
+            return {t, s.lens()};
        }
        else
        {
-            return s.with_lens(s.lens());
+            return s.with_lens(t, s.lens());
        }
    }
};
struct layernorm : layernorm_base<layernorm, 0>
{
    std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
@@ -80,8 +90,9 @@ struct find_layernorm
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
+        auto eps   = r.instructions["eps"]->eval().at<float>();
-        m.replace_instruction(ins, layernorm{}, x_ins);
+        m.replace_instruction(ins, layernorm{eps}, x_ins);
    }
};
@@ -89,15 +100,17 @@ struct find_add_layernorm
    auto matcher() const
    {
-        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
+        return match::layernorm()(
+            match::var("x")(match::name("add")(match::used_once()).bind("add")));
    }
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins     = r.result;
        auto add_ins = r.instructions["add"];
+        auto eps     = r.instructions["eps"]->eval().at<float>();
-        m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
+        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
    }
};
} // namespace
...
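Carrying epsilon through reflect() matters because reflection is what drives an operator's value conversion, hashing, and equality in MIGraphX; without it, layernorms fused from graphs with different epsilons would serialize and compare as identical. The matcher change is a correctness guard as well: match::used_once() keeps the fusion from absorbing an add whose result is still consumed elsewhere in the graph.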
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
}
argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape(), int8_x4_format);
auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
return args[3];
}
shape miopen_quant_convolution::find(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(inputs[0], int8_x4_format);
auto w_desc = make_tensor(inputs[1], int8_x4_format);
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x_shape = inputs[0];
auto w_shape = inputs[1];
if(int8_x4_format)
{
x_shape = pack_int8_shape(x_shape);
w_shape = pack_int8_shape(w_shape);
}
auto x = to_gpu(generate_argument(x_shape));
auto w = to_gpu(generate_argument(w_shape));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
workspace.implicit(),
workspace_size,
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution failed");
solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}};
}
void miopen_quant_convolution::finalize(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
if(cd == nullptr)
cd = make_conv(op);
if(solution_id == 0)
{
// Check that workspace hasn't changed
auto size = inputs.at(2).bytes();
auto ws = find(ctx, output_shape, inputs);
if(ws.bytes() > size)
MIGRAPHX_THROW("MIOpen Quant Convolution: workspace has changed during finalization.");
}
auto x_desc = make_tensor(inputs[0], int8_x4_format);
auto w_desc = make_tensor(inputs[1], int8_x4_format);
auto y_desc = make_tensor(output_shape);
auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_id);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: compile solution failed");
}
shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
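A worked example of pack_int8_shape above (hypothetical tensor): an int8 NCHW shape {1, 3, 32, 32} with strides {3072, 1024, 32, 1} has its channel count rounded up to a multiple of 4, giving lens[1] = 4 and strides[0] = 4 * 1024 = 4096, so the batch stride steps over the padded channels that the int8x4 layout requires.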
@@ -21,7 +21,13 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
+#include <unordered_set>
+#include <migraphx/ranges.hpp>
+#include <migraphx/stringutils.hpp>
+#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/rocblas.hpp>
+#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
    return rb;
}
+const std::unordered_set<std::string>& get_rocblas_fp32_archs()
+{
+    static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
+    return supported_archs;
+}
+bool get_compute_fp32_flag()
+{
+    bool compute_fp32 = false;
+#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
+    if(contains(get_rocblas_fp32_archs(), device_name))
+        compute_fp32 = true;
+#endif
+    return compute_fp32;
+}
+bool get_int8_x4_format(context& ctx)
+{
+    bool int8_x4_format = true;
+#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
+    rocblas_gemm_flags flag;
+    rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
+    int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
+#endif
+    return int8_x4_format;
+}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.normalize_compute_shape({inputs.at(0)});
}
argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -42,7 +42,6 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
-#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
@@ -52,6 +51,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
+#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
@@ -112,8 +112,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        insert_pad{},
        dead_code_elimination{},
-        rewrite_batchnorm{},
-        dead_code_elimination{},
        rewrite_rnn{},
        dead_code_elimination{},
        inline_module{},
@@ -145,14 +143,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        eliminate_concat{concat_gpu_optimization{}},
        dead_code_elimination{},
-        pack_int8_args{},
+        compile_miopen{&gctx},
        dead_code_elimination{},
-        adjust_allocation{gpu_allocation_model{}},
+        pack_int8_args{},
        dead_code_elimination{},
        fuse_ops{&ctx, options.fast_math},
        dead_code_elimination{},
        replace_allocate{gpu_allocation_model{}, options.offload_copy},
        dead_code_elimination{},
+        adjust_allocation{gpu_allocation_model{}},
+        dead_code_elimination{},
        compile_ops{&ctx},
        dead_code_elimination{},
        write_literals{&ctx},
...
@@ -26,15 +26,12 @@
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
-#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp>
-#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
@@ -75,84 +72,6 @@ typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::ena
    return x;
}
-//
-// ref implementation of batch norm for inference
-//
-// inputs are:
-// args[0] -> input data buffer
-// args[1] -> gamma
-// args[2] -> bias
-// args[3] -> mini batch mean
-// args[4] -> mini batch variance
-//
-// The equation to compute batch norm for inference is:
-//
-// output[i] = bias + gamma * (input[i] - mean) / sqrt(variance + epsilon)
-//
-// the input data format should be nchw
-//
-struct ref_batch_norm_inference
-{
-    op::batch_norm_inference op;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-    std::string name() const { return "ref::batch_norm_inference"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
-    {
-        argument output{output_shape};
-        double epsilon           = op.epsilon;
-        auto input               = args[0];
-        auto arg_gamma           = args[1];
-        auto arg_bias            = args[2];
-        auto mini_batch_mean     = args[3];
-        auto mini_batch_variance = args[4];
-        if(op.bn_mode == op::batch_norm_inference::spatial)
-        {
-            visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
-                [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
-                    par_for(output_shape.elements(), [&](auto i) {
-                        auto idx = output_shape.multi(i);
-                        auto c   = idx[1];
-                        assert((variance[c] + epsilon) > 0);
-                        result[i] =
-                            gamma[c] * (buffer[i] - mean[c]) / std::sqrt(variance[c] + epsilon) +
-                            bias[c];
-                    });
-                });
-        }
-        if(op.bn_mode == op::batch_norm_inference::per_activation)
-        {
-            visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
-                [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
-                    par_for(output_shape.elements(), [&](auto i) {
-                        auto idx   = output_shape.multi(i);
-                        idx[0]     = 0;
-                        auto index = output_shape.index(idx);
-                        assert((variance[index] + epsilon) > 0);
-                        result[i] = gamma[index] * (buffer[i] - mean[index]) /
-                                        std::sqrt(variance[index] + epsilon) +
-                                    bias[index];
-                    });
-                });
-        }
-        return output;
-    }
-};
-MIGRAPHX_REGISTER_OP(ref_batch_norm_inference)
struct ref_lrn
{
    op::lrn op;
@@ -237,15 +156,16 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        std::vector<std::size_t> padding;
-        if(op.use_dynamic_same_auto_pad)
+        if(op.padding_mode != op::padding_mode_t::default_)
        {
            auto input_lens = args[0].get_shape().lens();
-            std::vector<std::size_t> img_lens{input_lens.begin() + 2, input_lens.end()};
            auto weights_lens = args[1].get_shape().lens();
-            std::vector<std::size_t> k_lens{weights_lens.begin() + 2, weights_lens.end()};
-            padding = calc_dyn_auto_pad(img_lens, k_lens, op.stride, op.dilation);
-            output_shape =
-                compute_padded_shape({args.at(0).get_shape(), args.at(1).get_shape()}, padding);
+            padding =
+                op.padding_mode == op::same_upper
+                    ? calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, true)
+                    : calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, false);
+            output_shape = compute_padded_shape(
+                args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
        }
        else
        {
@@ -313,34 +233,6 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
        });
        return result;
    }
-    private:
-    /*!
-     * Used for dynamic auto padding since padding needs to be computed at evaluation time.
-     * \param inputs two fixed shape inputs [input_tensor, weights]
-     * \param padding from auto_pad calculation
-     */
-    shape compute_padded_shape(const std::vector<shape>& inputs,
-                               const std::vector<std::size_t>& padding) const
-    {
-        const shape& input            = inputs.at(0);
-        const shape& weights          = inputs.at(1);
-        const size_t num_spatial_dims = input.lens().size() - 2;
-        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
-        // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
-        for(size_t i = 0; i < num_spatial_dims; i++)
-        {
-            auto padding_factor = padding[i] + padding[i + num_spatial_dims];
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
-                1,
-                (input.lens()[i + 2] - (1 + op.dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 padding_factor) /
-                        op.stride[i] +
-                    1)));
-        }
-        return inputs[0].with_lens(output_lens);
-    }
};
struct ref_im2col
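For reference, the removed helper computed each spatial output length with the standard convolution formula ((W - K_eff + P_total) / S) + 1, where K_eff = 1 + dilation * (K - 1). For example, W = 32, K = 3, dilation = 1, padding 1 on each side (P_total = 2), and stride 1 gives (32 - 3 + 2) / 1 + 1 = 32, the familiar "same" convolution case.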
@@ -454,10 +346,10 @@ struct ref_pad
    std::string name() const { return "ref::pad"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
    {
-        assert(output_shape.standard());
-        argument result{output_shape};
+        assert(dyn_out.computed_shape.standard());
+        argument result{dyn_out.computed_shape};
        result.visit([&](auto output) {
            using type = typename decltype(output)::value_type;
            std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
...@@ -491,9 +383,9 @@ struct ref_gemm ...@@ -491,9 +383,9 @@ struct ref_gemm
std::string name() const { return "ref::dot"; } std::string name() const { return "ref::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
migemm(result, args[0], args[1], 1.0f, 0.0f); migemm(result, args[0], args[1], 1.0f, 0.0f);
return result; return result;
...@@ -537,65 +429,6 @@ struct ref_quant_gemm ...@@ -537,65 +429,6 @@ struct ref_quant_gemm
}; };
MIGRAPHX_REGISTER_OP(ref_gemm) MIGRAPHX_REGISTER_OP(ref_gemm)
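
With alpha = 1.0f and beta = 0.0f, the migemm call reduces C = alpha * A * B + beta * C to a plain matrix product. A naive reference for the row-major 2-D case (illustrative only; the real migemm also handles batched and strided arguments):

#include <cstddef>
#include <vector>

// C (m x n) = alpha * A (m x k) * B (k x n) + beta * C, all row-major.
void gemm_ref(std::vector<float>& c,
              const std::vector<float>& a,
              const std::vector<float>& b,
              std::size_t m,
              std::size_t k,
              std::size_t n,
              float alpha = 1.0f,
              float beta  = 0.0f)
{
    for(std::size_t i = 0; i < m; i++)
    {
        for(std::size_t j = 0; j < n; j++)
        {
            float acc = 0.0f;
            for(std::size_t l = 0; l < k; l++)
                acc += a[i * k + l] * b[l * n + j];
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}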
struct leaky_relu_op
{
op::leaky_relu op;
std::string name() const { return "ref::leaky_relu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : x * a; };
}
};
struct elu_op
{
op::elu op;
std::string name() const { return "ref::elu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
}
};
template <typename Op>
struct ref_unary : auto_register_op<ref_unary<Op>>
{
ref_unary() = default;
template <class T>
ref_unary(T pop) : op(Op{std::move(pop)})
{
}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op.op, f);
}
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
assert(input.get_shape().standard());
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
return result;
}
};
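
The removed ref_unary wrapper pairs a small policy struct (a name plus an fcn() returning the elementwise lambda) with one generic std::transform over the input tensor. The same pattern in isolation, using the leaky-relu and elu formulas from above (standalone sketch, not the MIGraphX types):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

template <class F>
std::vector<float> apply_unary(std::vector<float> xs, F f)
{
    std::transform(xs.begin(), xs.end(), xs.begin(), f);
    return xs;
}

int main()
{
    float alpha = 0.1f;
    // leaky_relu_op::fcn(): x > 0 ? x : x * alpha
    auto lrelu = apply_unary({-2.0f, 3.0f}, [=](float x) { return x > 0 ? x : x * alpha; });
    // elu_op::fcn(): x > 0 ? x : alpha * (exp(x) - 1)
    auto elu = apply_unary({-2.0f, 3.0f}, [=](float x) { return x > 0 ? x : alpha * std::expm1(x); });
    std::cout << lrelu[0] << " " << elu[0] << "\n"; // -0.2 and about -0.0865
    return 0;
}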
template <class Op>
struct ref_softmax : auto_register_op<ref_softmax<Op>>
{
...@@ -616,10 +449,10 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
    {
        return op.normalize_compute_shape(inputs);
    }
    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
    {
        argument result{dyn_out.computed_shape};
        auto batch_lens = dyn_out.computed_shape.lens();
        int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
        std::size_t n_dims = batch_lens[tuned_axis];
        batch_lens[tuned_axis] = 1;
...@@ -642,7 +475,7 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
                for(std::size_t j = 0; j < n_dims; ++j)
                {
                    idx[tuned_axis] = j;
                    std::size_t index = dyn_out.computed_shape.index(idx);
                    output[index] = std::exp(input[index] - batch_max[i]);
                }
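
The exp(input[index] - batch_max[i]) step is the standard max-subtraction trick: softmax is invariant to shifting its input, so subtracting the per-slice maximum keeps every exponent at or below zero and avoids overflow without changing the result. The same idea on a single 1-D slice (sketch; the reference op applies it along tuned_axis of every batch slice):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_ref(const std::vector<float>& x)
{
    float m = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.0f;
    for(std::size_t i = 0; i < x.size(); i++)
    {
        y[i] = std::exp(x[i] - m); // every exponent is <= 0, so no overflow
        sum += y[i];
    }
    for(auto& v : y)
        v /= sum;
    return y;
}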
...@@ -731,16 +564,12 @@ struct ref_apply
    void init()
    {
apply_map["batch_norm_inference"] =
extend_op<ref_batch_norm_inference, op::batch_norm_inference>();
apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>(); apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>(); apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>(); apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] = apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>(); extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["elu"] = extend_op<ref_unary<elu_op>, op::elu>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>(); apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["leaky_relu"] = extend_op<ref_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>(); apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>(); apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
apply_map["pad"] = extend_op<ref_pad, op::pad>(); apply_map["pad"] = extend_op<ref_pad, op::pad>();
......
...@@ -23,6 +23,7 @@
 */
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/tf/tf_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
...@@ -38,16 +39,37 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& /*parser*/,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        // different default epsilon than from ONNX
        float epsilon = 1e-4f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = info.attributes.at("epsilon").f();
        }

        auto x_lens = args[0]->get_shape().lens();
auto x_type = args[0]->get_shape().type();
        // unsqueeze the (C)-shaped tensors to (C, 1, 1) so they broadcast against the NCHW input
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[1]);
auto bias_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[2]);
auto mean_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[3]);
auto var_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
    }
};
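
The parser now lowers fused batch norm into primitive ops: unsqueezing the (C)-shaped parameters lets them broadcast against the input, and pow(var + eps, 0.5) stands in for sqrt in the denominator. A scalar check that the op-by-op decomposition matches the closed form (assumed equivalent to the removed batch_norm_inference op):

#include <cassert>
#include <cmath>

// One element of y = scale * (x - mean) / sqrt(var + eps) + bias, written
// exactly as the graph builds it: sub, add, pow(., 0.5), div, mul, add.
float batchnorm_elem(float x, float scale, float bias, float mean, float var, float eps)
{
    float numer = x - mean;                  // sub
    float denom = std::pow(var + eps, 0.5f); // add, then pow as sqrt
    return numer / denom * scale + bias;     // div, mul, add
}

int main()
{
    float y = batchnorm_elem(2.0f, 1.5f, 0.25f, 1.0f, 4.0f, 1e-4f);
    assert(std::fabs(y - (1.5f / std::sqrt(4.0001f) + 0.25f)) < 1e-5f);
    return 0;
}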
......
...@@ -75,7 +75,6 @@ struct parse_conv : op_parser<parse_conv>
        const std::string& pad_mode = info.attributes.at("padding").s();
        if(pad_mode.find("SAME") != std::string::npos)
        {
op.padding_mode = op::padding_mode_t::same;
            std::vector<size_t> weight_dims = weights->get_shape().lens();
            size_t weight_h = weight_dims[2];
            size_t weight_w = weight_dims[3];
...@@ -87,10 +86,6 @@ struct parse_conv : op_parser<parse_conv>
            op.padding = std::vector<size_t>(pads.begin(), pads.end());
        }
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
else if(pad_mode.find("EXPLICIT") != std::string::npos) else if(pad_mode.find("EXPLICIT") != std::string::npos)
{ {
std::vector<size_t> padding; std::vector<size_t> padding;
......
...@@ -80,7 +80,6 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
        if(pad_mode.find("SAME") != std::string::npos)
        {
op.padding_mode = op::padding_mode_t::same;
            std::vector<size_t> weight_dims = weights->get_shape().lens();
            size_t weight_h = weight_dims[2];
            size_t weight_w = weight_dims[3];
...@@ -101,10 +100,6 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
                op.padding[1] = pads[1];
            }
        }
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
    }
    std::vector<int64_t> new_weights_shape;
......
...@@ -347,7 +347,7 @@ void tf_parser::parse_node(const std::string& name)
            // input was from a node with multiple outputs
            if(contains(input_name, ':'))
            {
                input_name.resize(input.find(':'));
            }
            else
            {
......