Commit 6be3baa1 authored by Alan Turner

Merge

parents 3d4fb6ae 214b313f
@@ -5,14 +5,14 @@ on:
     branches: [develop]
     types: [opened, synchronize, closed]
   schedule:
-    - cron: "0 5 * * 1-6"
+    - cron: "0 6 * * 1-6"
   workflow_dispatch:
     inputs:
       rocm_release:
         description: ROCm Version
         required: true
-        default: '5.2'
+        default: '5.3'
       performance_reports_repo:
         description: Result repository
         required: true
@@ -30,9 +30,9 @@ concurrency: "perftest-${{ github.head_ref || github.base_ref || 'schedule' }}"
 jobs:
   release:
-    uses: rocmsoftwareplatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
+    uses: ROCmSoftwarePlatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
     with:
-      rocm_release: ${{ github.event.inputs.rocm_release || '5.2' }}
+      rocm_release: ${{ github.event.inputs.rocm_release || '5.3' }}
       result_number: ${{ github.event.inputs.result_number || '10' }}
       flags: ${{ github.event.inputs.flags || '-s' }}
       performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || 'ROCmSoftwarePlatform/migraphx-reports' }}
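(For reference, the five cron fields are minute, hour, day-of-month, month, and day-of-week, so the change above moves the scheduled run from 05:00 to 06:00 UTC, Monday through Saturday.)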
......
@@ -29,6 +29,7 @@ See below for a comprehensive list of commands and option arguments, as well as
 | --tf | Load file as a tensorflow graph |
 | --migraphx | Load file as a migraphx graph |
 | --migraphx-json | Load file as a migraphx JSON graph |
+| --batch | Set batch size for the model |
 | --nhwc | Treat tensorflow format as nhwc |
 | --nchw | Treat tensorflow format as nchw |
 | --skip-unknown-operators | Skip unknown operators when parsing and continue to parse |
......
@@ -21,6 +21,6 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 #####################################################################################
-tensorflow==2.7.2
+tensorflow==2.9.3
 onnxruntime
 tokenizers
\ No newline at end of file
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_INT_DIVIDE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT_DIVIDE_HPP
#include <migraphx/config.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class R, class T, class U>
R floor_divide(T x, U y)
{
return R(std::floor(double(x) / double(y)));
}
template <class R, class T, class U>
R ceil_divide(T x, U y)
{
return R(std::ceil(double(x) / double(y)));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
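A note on these helpers, since the pooling change below stops using them: routing through `double` gives mathematical floor/ceil semantics, whereas native C++ integer division truncates toward zero and disagrees on negative operands. A minimal self-contained check (the `main` driver is illustrative only, not part of the header):

```cpp
#include <cassert>
#include <cmath>

// Same shape as the header's helper: floor semantics via double.
template <class R, class T, class U>
R floor_divide(T x, U y)
{
    return R(std::floor(double(x) / double(y)));
}

int main()
{
    assert(-3 / 2 == -1);                   // native division truncates toward zero
    assert(floor_divide<int>(-3, 2) == -2); // floor rounds toward negative infinity
    assert(floor_divide<int>(3, 2) == 1);   // both agree for non-negative operands
    return 0;
}
```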
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
         fill(start, end);
     }

+    // Directly copies buffer of x
     template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
     literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
     {
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
     std::shared_ptr<char> buffer;
     shape m_shape;

+    // Keeps the same data ordering as the given container
     template <class Iterator>
     void fill(Iterator start, Iterator end)
     {
         assert(std::distance(start, end) == m_shape.elements());
-        if(m_shape.standard())
-        {
-            m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
-        }
-        else
-        {
-            auto it = start;
-            m_shape.visit_type([&](auto as) {
-                auto output = make_view(m_shape, as.from(buffer.get()));
-                shape_for_each(output.get_shape(), [&](const auto& idx) {
-                    output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
-                    it++;
-                });
-            });
-        }
+        m_shape.visit_type([&](auto as) {
+            auto output = make_view(m_shape, as.from(buffer.get()));
+            std::copy(start, end, output.begin());
+        });
     }
 };
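The simplification works because the view's iterator already walks elements in logical index order while mapping each position through the shape's strides, so one `std::copy` covers both the standard and non-standard layouts the old branch distinguished. A dependency-free sketch of that idea on a 2x3 column-major buffer (helper names here are illustrative, not MIGraphX's):

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Write values in logical row-major order into a strided (here: column-major) buffer.
void copy_through_strides(const std::vector<int>& src,
                          std::vector<int>& dst,
                          std::array<std::size_t, 2> lens,
                          std::array<std::size_t, 2> strides)
{
    auto it = src.begin();
    for(std::size_t i = 0; i < lens[0]; i++)
        for(std::size_t j = 0; j < lens[1]; j++)
            dst[i * strides[0] + j * strides[1]] = *it++;
}

int main()
{
    std::vector<int> src = {1, 2, 3, 4, 5, 6}; // logical 2x3: {{1,2,3},{4,5,6}}
    std::vector<int> dst(6, 0);
    copy_through_strides(src, dst, {2, 3}, {1, 2}); // column-major strides
    assert((dst == std::vector<int>{1, 4, 2, 5, 3, 6}));
    return 0;
}
```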
......
@@ -31,7 +31,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/shape_for_each.hpp>
-#include <migraphx/int_divide.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <cmath>
 #include <utility>
@@ -49,6 +49,9 @@ struct pooling
     bool ceil_mode = false;
     int lp_order   = 2;

+    // Global pooling with dynamic shape input
+    bool dyn_global = false;
+
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
@@ -57,7 +60,8 @@ struct pooling
                       f(self.stride, "stride"),
                       f(self.lengths, "lengths"),
                       f(self.ceil_mode, "ceil_mode"),
-                      f(self.lp_order, "lp_order"));
+                      f(self.lp_order, "lp_order"),
+                      f(self.dyn_global, "dyn_global"));
     }

     std::string name() const { return "pooling"; }
@@ -65,51 +69,111 @@ struct pooling
     void check_attribute_size() const
     {
         if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
-           stride.size() != lengths.size())
+           (not dyn_global and stride.size() != lengths.size()))
         {
             MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
         }
     }

+    size_t kdims() const
+    {
+        check_attribute_size();
+        return stride.size();
+    }
+
     value attributes() const { return {{"normalize_padding", "padding"}}; }

+    std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
+                                                  std::size_t kdims) const
+    {
+        std::vector<std::size_t> output_lens{};
+        for(size_t i = 0; i < kdims; ++i)
+        {
+            if(input_lens[i + 2] == 0)
+            {
+                // handle opt = 0
+                output_lens.push_back(0);
+            }
+            else
+            {
+                std::size_t padding_factor = 2 * padding[i];
+                if(padding.size() == 2 * kdims)
+                    padding_factor = padding[i] + padding[i + kdims];
+                assert(input_lens[i + 2] + padding_factor >= lengths[i]);
+                std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
+                std::size_t len =
                    (ceil_mode)
                        ? dim_size / stride[i] +
                              static_cast<std::size_t>((dim_size % stride[i] != 0)) // ceil uint divide
                        : dim_size / stride[i]; // floor divide
+                output_lens.push_back(len + 1);
+            }
+        }
+        return output_lens;
+    }
+
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
+        check_attribute_size();
         const shape& input = inputs.at(0);
-        auto input_lens = input.lens();
-        size_t kdims = input_lens.size() - 2;
-        auto input_size = inputs[0].lens().size();
         auto padding_size = padding.size();
-        if(input_size != padding_size / 2 + 2 and input_size != padding_size + 2)
+        size_t kdims      = input.ndim() - 2;
+        if(input.ndim() != padding_size / 2 + 2 and input.ndim() != padding_size + 2)
         {
             MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
         }
-        std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
-        for(size_t i = 0; i < kdims; i++)
+        if(input.dynamic())
         {
-            std::ptrdiff_t dim_size;
-            auto padding_factor = 2 * padding[i];
-            if(padding_size == 2 * kdims)
-                padding_factor = padding[i] + padding[i + kdims];
-            dim_size = input_lens[i + 2] + padding_factor - lengths[i];
-            assert(dim_size >= 0);
-            std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
-                                          : floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
+            auto input_dyn_dims = input.dyn_dims();
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input_dyn_dims.begin(),
+                                                                  input_dyn_dims.begin() + 2);
+            if(dyn_global)
+            {
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_dyn_dims.push_back(shape::dynamic_dimension{1, 1, 1});
+                }
+                return {input.type(), output_dyn_dims};
+            }
+            else
+            {
+                auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims);
+                auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims);
+                auto opt_spatial_dims = calc_spatial_dim_out(input.opt_lens(), kdims);
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_dyn_dims.push_back(shape::dynamic_dimension{
+                        min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
+                }
+                return {input.type(), output_dyn_dims};
+            }
         }
-        return inputs[0].with_lens(output_lens);
-    }
-
-    size_t kdims() const
-    {
-        check_attribute_size();
-        return stride.size();
+        else
+        {
+            auto input_lens = input.lens();
+            std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
+            // Used for when normalize_compute_shape() is called again at model eval time
+            // for an originally dynamic shape. Since kernel shape is not used with dyn_global.
+            if(dyn_global)
+            {
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_lens.push_back(1);
+                }
+                return {input.type(), output_lens};
+            }
+            else
+            {
+                auto output_spatial_lens = calc_spatial_dim_out(input_lens, kdims);
+                output_lens.insert(
                    output_lens.end(), output_spatial_lens.begin(), output_spatial_lens.end());
+                return inputs[0].with_lens(output_lens);
+            }
+        }
     }

     struct lpnorm_pool
@@ -158,7 +222,11 @@ struct pooling
     };

     template <class Type, class Out, class In, class Op>
-    void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const
+    void calc_pooling(const shape& output_shape,
+                      Out& output,
+                      const In& input,
+                      const std::vector<std::size_t>& kernel_dims,
+                      Op op) const
     {
         auto in_s    = input.get_shape();
         auto in_lens = in_s.lens();
@@ -172,7 +240,7 @@ struct pooling
             auto d_2 = dim - 2;
             int start =
                 static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
-            int end = std::min(start + lengths[d_2], in_lens[dim]);
+            int end = std::min(start + kernel_dims[d_2], in_lens[dim]);
             start   = std::max(start, 0);
             win_start.push_back(start);
             win_size.push_back(end - start);
@@ -198,21 +266,32 @@ struct pooling
         });
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
+        auto input_lens = args[0].get_shape().lens();
+        std::vector<std::size_t> kernel_dims;
+        if(dyn_global)
+        {
+            kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
+        }
+        else
+        {
+            kernel_dims = this->lengths;
+        }
         visit_all(result, args[0])([&](auto output, auto input) {
             using type = typename decltype(output)::value_type;
             switch(mode)
             {
             case migraphx::op::pooling_mode::average:
-                calc_pooling<type>(output_shape, output, input, avg_pool{});
+                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{});
                 break;
             case migraphx::op::pooling_mode::max:
-                calc_pooling<type>(output_shape, output, input, max_pool{});
+                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, max_pool{});
                 break;
             case migraphx::op::pooling_mode::lpnorm:
-                calc_pooling<type>(output_shape, output, input, lpnorm_pool{lp_order});
+                calc_pooling<type>(
+                    dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order});
                 break;
             }
         });
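Aside on the arithmetic in `calc_spatial_dim_out`: it is the usual pooling relation `out = (D + pad_total - K) / S + 1`, with the division floored, or ceiled when `ceil_mode` is set. A standalone check of just that formula (free function for illustration, not part of the operator):

```cpp
#include <cassert>
#include <cstddef>

// out = (D + pad_total - K) / S + 1, with optional ceil rounding as in ceil_mode
std::size_t pool_out_dim(
    std::size_t d, std::size_t k, std::size_t s, std::size_t pad_total, bool ceil_mode)
{
    std::size_t dim_size = d + pad_total - k;
    std::size_t len      = ceil_mode ? dim_size / s + static_cast<std::size_t>(dim_size % s != 0)
                                     : dim_size / s;
    return len + 1;
}

int main()
{
    // 3-wide kernel, stride 2, pad 1 on each side over a 32-wide axis:
    // floor((32 + 2 - 3) / 2) + 1 = 15 + 1 = 16
    assert(pool_out_dim(32, 3, 2, 2, false) == 16);
    // ceil_mode rounds the partial window up: ceil(31 / 2) + 1 = 16 + 1 = 17
    assert(pool_out_dim(32, 3, 2, 2, true) == 17);
    return 0;
}
```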
......
@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -45,17 +46,15 @@ struct transpose
     }
     std::string name() const { return "transpose"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
         auto input = inputs.at(0);
-        auto input_lens = input.lens();
-        auto input_strides = input.strides();
-        auto t = input.type();
-        if(dims.size() != input_lens.size())
+        if(dims.size() != input.ndim())
         {
-            MIGRAPHX_THROW("Permutation has wrong number of axes");
+            MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
         }
         std::vector<int64_t> axes(dims.size());
         std::iota(axes.begin(), axes.end(), 0);
@@ -63,19 +62,36 @@ struct transpose
         {
             MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
         }
-        std::vector<size_t> output_lens(input_lens.size());
-        std::vector<size_t> output_strides(input_lens.size());
-        for(std::size_t i = 0; i < output_lens.size(); i++)
+
+        if(input.dynamic())
         {
-            output_lens[i]    = input_lens[dims[i]];
-            output_strides[i] = input_strides[dims[i]];
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
+            std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
+                return input.dyn_dims()[dim];
+            });
+            return {input.type(), output_dyn_dims};
+        }
+        else
+        {
+            auto input_lens    = input.lens();
+            auto input_strides = input.strides();
+            std::vector<size_t> output_lens(input.ndim());
+            std::vector<size_t> output_strides(input.ndim());
+            for(std::size_t i = 0; i < input.ndim(); i++)
+            {
+                output_lens[i]    = input_lens[dims[i]];
+                output_strides[i] = input_strides[dims[i]];
+            }
+            return {input.type(), output_lens, output_strides};
         }
-        return {t, output_lens, output_strides};
     }
-    argument compute(shape output_shape, std::vector<argument> args) const
+
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
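The dynamic branch mirrors the static one: it permutes per-axis `{min, max, opt}` bounds exactly as lengths are permuted. Reduced to standard containers (the `dyn_dim` struct below is an illustrative stand-in for `shape::dynamic_dimension`):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct dyn_dim
{
    std::size_t min, max, opt;
    bool operator==(const dyn_dim& o) const
    {
        return min == o.min and max == o.max and opt == o.opt;
    }
};

int main()
{
    // NCHW input where N is dynamic, permuted to NHWC with dims = {0, 2, 3, 1}
    std::vector<dyn_dim> in       = {{1, 4, 2}, {3, 3, 3}, {8, 8, 8}, {8, 8, 8}};
    std::vector<std::size_t> dims = {0, 2, 3, 1};
    std::vector<dyn_dim> out(in.size());
    std::transform(dims.cbegin(), dims.cend(), out.begin(), [&](auto d) { return in[d]; });
    assert((out == std::vector<dyn_dim>{{1, 4, 2}, {8, 8, 8}, {8, 8, 8}, {3, 3, 3}}));
    return 0;
}
```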
......
@@ -28,6 +28,7 @@
 #include <migraphx/value.hpp>
 #include <migraphx/reflect.hpp>
 #include <migraphx/requires.hpp>
+#include <migraphx/optional.hpp>
 #include <migraphx/rank.hpp>
 #include <type_traits>
@@ -87,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
     return result;
 }

+template <class T>
+auto to_value_impl(rank<4>, const optional<T>& x)
+{
+    value result{};
+    if(x.has_value())
+        to_value(*x);
+    return result;
+}
+
 template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
-value to_value_impl(rank<4>, const T& x)
+value to_value_impl(rank<5>, const T& x)
 {
     return std::int64_t{x};
 }

 template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
-value to_value_impl(rank<5>, const T& x)
+value to_value_impl(rank<6>, const T& x)
 {
     return std::uint64_t{x};
 }

 template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
-value to_value_impl(rank<6>, const T& x)
+value to_value_impl(rank<7>, const T& x)
 {
     return double{x};
 }

 template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
-value to_value_impl(rank<7>, const T& x)
+value to_value_impl(rank<8>, const T& x)
 {
     return x;
 }

-inline value to_value_impl(rank<8>, const std::string& x) { return x; }
+inline value to_value_impl(rank<9>, const std::string& x) { return x; }

 template <class T>
-auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
+auto to_value_impl(rank<10>, const T& x) -> decltype(migraphx_to_value(x))
 {
     return migraphx_to_value(x);
 }

 template <class T>
-auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
+auto to_value_impl(rank<11>, const T& x) -> decltype(x.to_value())
 {
     return x.to_value();
 }

 template <class T>
-auto to_value_impl(rank<11>, const T& x)
+auto to_value_impl(rank<12>, const T& x)
     -> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
 {
     value v;
@@ -195,28 +205,35 @@ void from_value_impl(rank<5>, const value& v, T& x)
     });
 }

+template <class T>
+void from_value_impl(rank<6>, const value& v, optional<T>& x)
+{
+    if(not v.is_null())
+        x = from_value<T>(v);
+}
+
 template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
-void from_value_impl(rank<6>, const value& v, T& x)
+void from_value_impl(rank<7>, const value& v, T& x)
 {
     x = v.to<T>();
 }

 template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
-void from_value_impl(rank<7>, const value& v, T& x)
+void from_value_impl(rank<8>, const value& v, T& x)
 {
     x = v.to<T>();
 }

-inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
+inline void from_value_impl(rank<9>, const value& v, std::string& x) { x = v.to<std::string>(); }

 template <class T>
-auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
+auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(x.from_value(v), void())
 {
     x.from_value(v);
 }

 template <class T>
-auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
+auto from_value_impl(rank<11>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
 {
     migraphx_from_value(v, x);
 }
@@ -226,13 +243,13 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
 template <class T>
 value to_value(const T& x)
 {
-    return detail::to_value_impl(rank<11>{}, x);
+    return detail::to_value_impl(rank<12>{}, x);
 }

 template <class T>
 void from_value(const value& v, T& x)
 {
-    detail::from_value_impl(rank<10>{}, v, x);
+    detail::from_value_impl(rank<11>{}, v, x);
 }
 } // namespace MIGRAPHX_INLINE_NS
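Context for the wholesale renumbering above: `to_value_impl`/`from_value_impl` use tag dispatch, where `rank<N>` derives from `rank<N-1>`, so a call made with the highest rank resolves to the most specific viable overload. Inserting the `optional<T>` overloads shifts every rank above them up by one, which is all this part of the diff does. A minimal sketch of the mechanism:

```cpp
#include <iostream>

// rank<N> inherits from rank<N-1>, so an argument of rank<2> prefers the
// rank<1> overload over rank<0> when both are viable.
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

const char* pick(rank<0>, ...) { return "generic fallback"; }
const char* pick(rank<1>, int) { return "int overload"; }

int main()
{
    std::cout << pick(rank<2>{}, 5) << "\n";      // "int overload": closer base class wins
    std::cout << pick(rank<2>{}, "text") << "\n"; // "generic fallback": int overload not viable
    return 0;
}
```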
......
@@ -31,6 +31,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+/**
+ * Iterates the given function over the indices from the shape in order.
+ */
 template <class F>
 void shape_for_each(const migraphx::shape& s, F f)
 {
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
         call(indices);
     }
 }
-
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
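The new comment documents the property that `literal::fill` above now leans on: indices are visited in order. A dependency-free sketch with the same iteration pattern (simplified signature taking lengths only, assuming all lengths are nonzero):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Visit every multi-index of a shape with the given lengths in row-major order.
template <class F>
void shape_for_each(const std::vector<std::size_t>& lens, F f)
{
    std::vector<std::size_t> idx(lens.size(), 0);
    while(true)
    {
        f(idx);
        // Increment like an odometer, last axis fastest.
        std::size_t axis = lens.size();
        while(axis > 0 and ++idx[axis - 1] == lens[axis - 1])
            idx[--axis] = 0;
        if(axis == 0)
            return;
    }
}

int main()
{
    // Prints (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
    shape_for_each({2, 3}, [](const auto& idx) {
        std::cout << "(" << idx[0] << "," << idx[1] << ") ";
    });
    std::cout << "\n";
    return 0;
}
```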
......
@@ -29,6 +29,7 @@
 #include <migraphx/reflect.hpp>
 #include <migraphx/rank.hpp>
 #include <migraphx/requires.hpp>
+#include <migraphx/optional.hpp>
 #include <migraphx/config.hpp>
 #include <vector>
@@ -99,12 +100,21 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
     os << "}";
 }

+template <class T>
+void stream_write_value_impl(rank<0>, std::ostream& os, const optional<T>& x)
+{
+    if(x.has_value())
+        stream_write_value_impl(rank<2>{}, os, *x);
+    else
+        os << "none";
+}
+
 } // namespace detail

 template <class T>
 void stream_write_value(std::ostream& os, const T& x)
 {
-    detail::stream_write_value_impl(rank<1>{}, os, x);
+    detail::stream_write_value_impl(rank<2>{}, os, x);
 }
 } // namespace MIGRAPHX_INLINE_NS
......
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
     {
         return;
     }

-    auto kdims = input->get_shape().lens().size() - 2;
+    auto kdims = input->get_shape().ndim() - 2;
     if(std::equal(op.padding.begin(),
                   op.padding.begin() + kdims,
                   op.padding.begin() + kdims,
                   op.padding.end()))
         return;

-    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
+    std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
     std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
     std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
     op.padding = std::vector<size_t>(kdims * 2, 0);
......
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
          {"GlobalLpPool", "lpnorm"}};
     }

-    instruction_ref parse(const op_desc& opd,
-                          const onnx_parser& /*parser*/,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
+    value handle_values(const op_desc& opd,
+                        onnx_parser::node_info info,
+                        const shape& in_shape,
+                        value values) const
     {
-        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
-            {"max", op::pooling_mode::max},
-            {"average", op::pooling_mode::average},
-            {"lpnorm", op::pooling_mode::lpnorm}};
-        std::string mode = opd.op_name;
-        if(not contains(mode_map, mode))
-        {
-            MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
-        }
-        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
-        value values = op.to_value();
-        auto l0 = args[0];
-        auto in_lens = l0->get_shape().lens();
-        assert(in_lens.size() > 2);
-        auto kdims = in_lens.size() - 2;
+        auto kdims = in_shape.ndim() - 2;
         if(starts_with(opd.onnx_name, "Global"))
         {
-            values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
+            // if spatial dimensions are dynamic use dyn_global flag
+            if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
+                                                  in_shape.dyn_dims().cend(),
+                                                  [](auto dd) { return not dd.is_fixed(); }))
+            {
+                values["dyn_global"] = true;
+                values["lengths"]    = std::vector<size_t>();
+            }
+            else
+            {
+                // works with static and fixed dynamic shape
+                auto m_lens       = in_shape.max_lens();
+                values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
+            }
         }

-        // does not support ceil_mode
         if(contains(info.attributes, "ceil_mode"))
         {
             values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
         }

-        // count include padding, if count include pad is 1, we always use
-        // explicit pad
-        int count_include_pad = 0;
-        if(contains(info.attributes, "count_include_pad"))
-        {
-            count_include_pad = info.attributes.at("count_include_pad").i();
-        }
-
         if(contains(info.attributes, "strides"))
         {
             values["stride"].clear();
             copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
             check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
         }

         if(contains(info.attributes, "kernel_shape"))
         {
             values["lengths"].clear();
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
         // ensure pads availabe only when auto_pad is "NOT_SET"
         check_padding_mode(info, "POOLING");

+        return values;
+    }
+
+    instruction_ref parse(const op_desc& opd,
+                          const onnx_parser& /*parser*/,
+                          onnx_parser::node_info info,
+                          std::vector<instruction_ref> args) const
+    {
+        std::string mode = opd.op_name;
+        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
+            {"max", op::pooling_mode::max},
+            {"average", op::pooling_mode::average},
+            {"lpnorm", op::pooling_mode::lpnorm}};
+        if(not contains(mode_map, mode))
+        {
+            MIGRAPHX_THROW(
+                "PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
+        }
+        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
+        value values = op.to_value();
+        auto l0       = args[0];
+        auto in_shape = l0->get_shape();
+        assert(in_shape.ndim() > 2);
+        auto kdims = in_shape.ndim() - 2;
+
+        values = handle_values(opd, info, in_shape, values);
+
+        // count include padding, if count include pad is 1, we always use
+        // explicit pad
+        int count_include_pad = 0;
+        if(contains(info.attributes, "count_include_pad"))
+        {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
+                               "dynamic input shape");
+            }
+            count_include_pad = info.attributes.at("count_include_pad").i();
+        }
+
         std::vector<int64_t> paddings;
         float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
@@ -122,6 +152,13 @@ struct parse_pooling : op_parser<parse_pooling>
         }

         if(contains(info.attributes, "auto_pad"))
+        {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
+            }
+            else
         {
             values["padding"].clear();
             // return paddings could be empty, then setting to 0 for no padding
@@ -129,9 +166,10 @@ struct parse_pooling : op_parser<parse_pooling>
                 values,
                 values["lengths"].to_vector<std::size_t>(),
                 {1, 1},
-                in_lens,
+                in_shape.lens(),
                 paddings);
         }
+        }

         if(paddings.size() != 2 * kdims)
         {
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
             values["stride"].resize(kdims);
             std::fill_n(values["stride"].begin(), kdims, 1);
         }
+
         // used to calculate the supposed output shape
         std::vector<int64_t> orig_padding = paddings;
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>

         if(not slice_start.empty())
         {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
+            }
             // calculate expected output shape
             orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
             orig_padding.insert(orig_padding.begin(), 2, 0);
......
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
         }

         // if perm is empty, use the default value
-        auto n_dim = args.front()->get_shape().lens().size();
+        auto n_dim = args.front()->get_shape().ndim();
         if(perm.empty())
         {
             perm.resize(n_dim);
......
@@ -237,14 +237,17 @@ endif()
 include(CheckLibraryExists)
 get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
 check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
-# check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
-# if(HAS_FIND_2_API)
-#     target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
-#     message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
-# else()
-#     message(STATUS "MIOpen does not have Find-2.0 API")
-# endif()
+check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
+
+# TODO: Set default to HAS_FIND_2_API
+set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
+
+if(MIGRAPHX_USE_FIND_2_API)
+    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+    message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
+else()
+    message(STATUS "MIGraphx is using legacy Find API in MIOpen")
+endif()

 if(HAS_FIND_MODE_API)
     target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_MODE_API)
@@ -42,13 +42,15 @@ struct precompile_op
     operation op                 = op::identity{};
     std::size_t additional_args  = 1;
     bool ignore_modules          = false;
+    optional<shape> output_shape = {};

     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack(f(self.op, "op"),
                     f(self.additional_args, "additional_args"),
-                    f(self.ignore_modules, "ignore_modules"));
+                    f(self.ignore_modules, "ignore_modules"),
+                    f(self.output_shape, "output_shape"));
     }

     std::string name() const { return "gpu::precompile_op"; }
@@ -57,9 +59,14 @@ struct precompile_op
     {
         // Pop off additional args
         inputs.resize(inputs.size() - additional_args);
-        if(ignore_modules)
-            return op.compute_shape(inputs);
-        return op.compute_shape(inputs, mods);
+        shape r{};
+        if(ignore_modules)
+            r = op.compute_shape(inputs);
+        else
+            r = op.compute_shape(inputs, mods);
+        if(output_shape.has_value())
+            r = *output_shape;
+        return r;
     }

     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
......
@@ -44,7 +44,10 @@ struct ck_gemm
         auto b = inputs[1];
         for(const auto& input : inputs)
             check_gemm_shape(input);
-        return op.compute_shape({a, b});
+        auto r = op.compute_shape({a, b});
+        if(mods.empty())
+            return r;
+        return r.with_type(mods.front()->get_output_shapes().front().type());
     }
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);
......
@@ -675,6 +675,34 @@ struct find_gemm_pointwise
     }
 };

+struct find_contiguous_tranpose_precompile
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op")))
+                .bind("transpose")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto op_ins    = r.instructions["op"];
+        auto alloc     = op_ins->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm  = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm = invert_permutation(perm);
+        auto s =
+            shape::from_permutation(op_ins->get_shape().type(), op_ins->get_shape().lens(), iperm);
+        auto v            = op_ins->get_operator().to_value();
+        v["output_shape"] = to_value(s);
+        auto new_op       = make_op("gpu::precompile_op", v);
+        m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs());
+        m.replace_instruction(ins, transpose);
+    }
+};
+
 struct find_contiguous_tranpose_gemm
 {
     auto matcher() const
@@ -850,6 +878,7 @@ void fuse_ops::apply(module& m) const
                         find_concat_pointwise{},
                         find_gemm_pointwise{},
                         find_contiguous_tranpose_gemm{},
+                        find_contiguous_tranpose_precompile{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
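The pass hinges on the permutation inversion: if the precompile op writes its output through a shape laid out with the inverse permutation, the following `transpose` becomes a pure stride shuffle and the `gpu::contiguous` copy can be dropped. A small check of the inversion identity (re-implemented here for illustration; MIGraphX's own `invert_permutation` is assumed equivalent):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// iperm[perm[i]] = i: applying perm and then iperm restores the original axis order.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& perm)
{
    std::vector<int64_t> iperm(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        iperm[perm[i]] = static_cast<int64_t>(i);
    return iperm;
}

int main()
{
    std::vector<int64_t> perm = {0, 2, 3, 1}; // NCHW -> NHWC
    auto iperm                = invert_permutation(perm);
    assert((iperm == std::vector<int64_t>{0, 3, 1, 2})); // NHWC -> NCHW
    return 0;
}
```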
......
@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument
@@ -159,7 +159,7 @@ struct hip_copy
     std::string name() const { return "hip::copy"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
+        check_shapes{inputs, *this}.has(2).same_type();
         return inputs.at(1);
     }
     argument compute(context& ctx, const shape&, std::vector<argument> args) const
......
@@ -50,6 +50,7 @@ using namespace migraphx::gpu::gen; // NOLINT

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);

 // NOLINTNEXTLINE
@@ -136,6 +137,8 @@ struct instance
     std::string str() const { return join_strings(params, ","); }
 };

+static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
+
 template <class F, class Action>
 auto action_decorate(F f, Action action)
 {
@@ -153,6 +156,21 @@ static std::vector<tuning_entry> read_tuning(const std::string& s)
     return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
 }

+static float matrix_distance(const shape& x, const shape& y)
+{
+    if(x.type() != y.type())
+        return std::numeric_limits<float>::max();
+    if(transposed_matrix(x) != transposed_matrix(y))
+        return std::numeric_limits<float>::max();
+    auto sum_squared = std::inner_product(x.lens().rbegin(),
+                                          x.lens().rbegin() + 2,
+                                          y.lens().rbegin(),
+                                          0,
+                                          std::plus<>{},
+                                          [](auto a, auto b) { return (a - b) * (a - b); });
+    return std::sqrt(sum_squared);
+}
+
 static std::size_t get_tuning_for(const std::vector<shape>& inputs)
 {
     static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
@@ -163,7 +181,26 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
     if(it == tuning.end())
     {
         std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
-        return 4;
+        std::vector<std::pair<float, std::size_t>> w;
+        std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
+            if(inputs.size() < 3 or p.first.size() < 3)
+                MIGRAPHX_THROW("Invalid CK config");
+            auto avg_distance = std::inner_product(
+                p.first.begin(),
+                p.first.begin() + 3,
+                inputs.begin(),
+                0.0f,
+                std::plus<>{},
+                [](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
+            return std::make_pair(avg_distance, p.second);
+        });
+        std::sort(w.begin(), w.end());
+        std::size_t default_value = 4;
+        if(not w.empty())
+            default_value = w.front().second;
+        auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
+        std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
+        return tuning_val;
     }
     return it->second;
 }
@@ -172,7 +209,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
 {
     static std::string get_layout(const shape& s)
     {
-        return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
-                              : "ck::tensor_layout::gemm::RowMajor";
+        return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
+                                    : "ck::tensor_layout::gemm::RowMajor";
     }
@@ -191,6 +228,22 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         return "ck::Tuple<" + join_strings(s, ",") + ">";
     }

+    static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
+    {
+        swap_inputs  = false;
+        auto c_shape = inputs.back();
+        if(not transposed_matrix(c_shape))
+            return inputs;
+        std::vector<int64_t> perm(c_shape.lens().size());
+        std::iota(perm.begin(), perm.end(), 0);
+        std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
+        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
+            return reorder_shape(s, perm);
+        });
+        swap_inputs = true;
+        return inputs;
+    }
+
     std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

     operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
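The fallback path above ranks known tuning entries by how close their three matrix shapes are to the unseen config, using Euclidean distance over the two trailing dimensions. A standalone check of that distance arithmetic (simplified to plain integer lens, without the type and layout guards):

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

// Distance between two matrices by their last two dimensions, as in matrix_distance.
float matrix_distance(const std::vector<long>& x, const std::vector<long>& y)
{
    auto sum_squared = std::inner_product(x.rbegin(),
                                          x.rbegin() + 2,
                                          y.rbegin(),
                                          0L,
                                          std::plus<>{},
                                          [](auto a, auto b) { return (a - b) * (a - b); });
    return std::sqrt(static_cast<float>(sum_squared));
}

int main()
{
    // (1024 - 1000)^2 + (768 - 768)^2 = 576 -> distance 24
    assert(matrix_distance({1, 1024, 768}, {1, 1000, 768}) == 24.0f);
    return 0;
}
```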
......
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
 #include <migraphx/kernels/reduce.hpp>
 #include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/vec.hpp>
 #include <migraphx/kernels/print.hpp>

 namespace migraphx {