Commit 7e316254 authored by turneram
Browse files

Merge remote-tracking branch 'origin/develop' into bert-attention-no-transpose-ops

parents a80f5b19 ebdddf58
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP #ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP #define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP #ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP #define MIGRAPHX_GUARD_KERNELS_MATH_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP #ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP
#define MIGRAPHX_GUARD_KERNELS_OPS_HPP #define MIGRAPHX_GUARD_KERNELS_OPS_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP #ifndef MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP #define MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
...@@ -18,8 +41,15 @@ struct implicit_conversion_op ...@@ -18,8 +41,15 @@ struct implicit_conversion_op
template <index_int N, class U> template <index_int N, class U>
constexpr operator vec<U, N>() const constexpr operator vec<U, N>() const
{ {
static_assert(vec_size<T>() == N, "Vector mismatch size"); if constexpr(vec_size<T>() == 0)
return __builtin_convertvector(x, vec<U, N>); {
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
} }
template <class U> template <class U>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP #ifndef MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
#define MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP #define MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
...@@ -163,7 +186,8 @@ __device__ auto auto_preload(index idx) ...@@ -163,7 +186,8 @@ __device__ auto auto_preload(index idx)
{ {
return make_transform([=](auto f, auto... xs) { return make_transform([=](auto f, auto... xs) {
auto invoke = [=](auto... ys) { auto invoke = [=](auto... ys) {
__syncthreads(); if constexpr((Bs or ...))
__syncthreads();
f(ys...); f(ys...);
}; };
join(invoke, preload_copy<Bs>(idx, xs)...); join(invoke, preload_copy<Bs>(idx, xs)...);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP #ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#define MIGRAPHX_GUARD_KERNELS_PRINT_HPP #define MIGRAPHX_GUARD_KERNELS_PRINT_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_REDUCE_HPP #ifndef MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
#define MIGRAPHX_GUARD_KERNELS_REDUCE_HPP #define MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
...@@ -152,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f) ...@@ -152,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f)
}; };
} }
/// Build a shape derived from Input's compile-time shape in which the length
/// at position `Axis` is replaced by 1 and every other length is kept.
/// The strides are taken from the input shape unchanged.
template <class Input, index_int Axis>
constexpr auto compute_reduce_axis()
{
    constexpr auto input_shape = get_shape_c<Input>{};
    // The callback is given (length, position); only the entry at `Axis` is collapsed.
    constexpr auto reduced_lens =
        transform_i(input_shape.lens, [](index_int len, index_int pos) -> index_int {
            return (pos == Axis) ? index_int{1} : len;
        });
    return make_shape(reduced_lens, input_shape.strides);
}

/// Type of the shape produced by compute_reduce_axis<Input, Axis>().
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block struct block
{ {
template <class Slicer> template <class Slicer>
...@@ -163,9 +201,12 @@ struct block ...@@ -163,9 +201,12 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { return vec_reduce(block_reduce(idx,
return read(x[j], xs[j]...); op,
}); init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
...@@ -175,6 +216,14 @@ struct block ...@@ -175,6 +216,14 @@ struct block
if(idx.local == 0) if(idx.local == 0)
f(); f();
} }
/// Apply `f` element-wise over the sliced tensors.
///
/// Returns the callable produced by `sliced(slicer, ...)`. When it is invoked
/// with tensors (x, xs...), `f(x[j], xs[j]...)` is called for each index `j`
/// handed out by `idx.local_stride` over `x.get_shape().elements()` entries.
/// NOTE(review): which `j` values a given thread receives is decided by
/// `idx.local_stride`, which is not visible here -- confirm before relying on it.
template <class F>
__device__ auto inner(F f) const
{
    auto per_slice = [=](auto x, auto... xs) {
        const auto count = x.get_shape().elements();
        idx.local_stride(count, [&](auto j) { f(x[j], xs[j]...); });
    };
    return sliced(slicer, per_slice);
}
}; };
template <class Slicer> template <class Slicer>
...@@ -221,6 +270,17 @@ struct lane ...@@ -221,6 +270,17 @@ struct lane
{ {
f(); f();
} }
/// Apply `f` element-wise over the sliced tensors with a plain sequential loop.
///
/// Returns the callable produced by `sliced(slicer, ...)`. When it is invoked
/// with tensors (x, xs...), `f(x[j], xs[j]...)` is called for every
/// j in [0, x.get_shape().elements()), in increasing order.
template <class F>
__device__ auto inner(F f) const
{
    auto per_slice = [=](auto x, auto... xs) {
        const auto count = x.get_shape().elements();
        for(index_int j = 0; j < count; j++)
            f(x[j], xs[j]...);
    };
    return sliced(slicer, per_slice);
}
}; };
template <class Slicer> template <class Slicer>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP #ifndef MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP
#define MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP #define MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP #ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP #define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP
...@@ -9,6 +32,7 @@ namespace migraphx { ...@@ -9,6 +32,7 @@ namespace migraphx {
template <class Lens, class Strides> template <class Lens, class Strides>
struct shape struct shape
{ {
using shape_type = shape;
using index_array = typename Lens::base_array; using index_array = typename Lens::base_array;
Lens lens = {}; Lens lens = {};
Strides strides = {}; Strides strides = {};
...@@ -21,7 +45,7 @@ struct shape ...@@ -21,7 +45,7 @@ struct shape
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; } constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr auto packed() const { return elements() == element_space(); } constexpr auto packed() const { return not skips() and elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; } constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const constexpr auto transposed() const
{ {
...@@ -30,16 +54,9 @@ struct shape ...@@ -30,16 +54,9 @@ struct shape
if(shape{}.broadcasted()) if(shape{}.broadcasted())
{ {
index_array s{}; index_array s{};
index_int j = 0; auto out = copy_if(
for(index_int i = 0; i < s.size(); i++) lstrides.begin(), lstrides.end(), s.begin(), [](auto x) { return x != 0; });
{ return not is_sorted(s.begin(), out, greater{});
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
} }
else else
{ {
...@@ -47,6 +64,13 @@ struct shape ...@@ -47,6 +64,13 @@ struct shape
} }
}); });
} }
constexpr auto skips() const
{
return return_c([] {
auto lstrides = Strides{};
return none_of(lstrides.begin(), lstrides.end(), [](auto x) { return x == 1; });
});
}
constexpr auto standard() const { return packed() and not transposed(); } constexpr auto standard() const { return packed() and not transposed(); }
...@@ -63,26 +87,34 @@ struct shape ...@@ -63,26 +87,34 @@ struct shape
constexpr index_int index(index_int i) const constexpr index_int index(index_int i) const
{ {
if(this->standard()) if(this->standard())
{
MIGRAPHX_ASSERT(i == compute_index(i));
return i; return i;
}
else else
{ {
const auto rank = this->lens.size(); return compute_index(i);
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
const index_int len = this->lens[k];
const index_int slen = s * len;
const index_int idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
} }
} }
/// Convert a linear element index `i` into an offset using `lens` and `strides`.
///
/// Dimensions are visited from the last to the first. For each one, the
/// coordinate is extracted from `i` as (i % (inner * len)) / inner, where
/// `inner` is the product of the lengths already visited, and the offset
/// grows by stride * coordinate.
constexpr index_int compute_index(index_int i) const
{
    const auto ndim  = this->lens.size();
    index_int offset = 0;
    index_int inner  = 1; // product of the lengths of the dimensions visited so far
    for(index_int step = 0; step < ndim; step++)
    {
        // Walk the dimensions back to front.
        const index_int dim   = ndim - step - 1;
        const index_int outer = inner * this->lens[dim];
        const index_int coord = (i % outer) / inner;
        offset += this->strides[dim] * coord;
        inner = outer;
    }
    return offset;
}
/// Convert single index into a multi-index /// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const constexpr index_array multi(index_int idx) const
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
/// Device-side softmax: writes exp(x - max) / sum(exp(x - max)) into `output`.
///
/// The work is dispatched through reduce::block::run with the shape
/// reduce::with_axis<Input, Axis>. For each invocation of the callback:
///   1. a max-reduction over `input` (initial value `lowest{}`) gives the shift,
///   2. a sum-reduction of exp(x - max) over `input` (initial value 0) gives the
///      normaliser,
///   3. r.inner stores exp(x - max) / sum into the matching element of `output`.
/// NOTE(review): presumably each callback invocation covers one slice along
/// `Axis`; that is decided by reduce::block::run, which is not shown here.
template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
    using reduce_shape = reduce::with_axis<Input, Axis>;
    reduce::block::run<reduce_shape>([&](auto, auto r) {
        auto slice_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
        // Shifted exponential, shared by the sum-reduction and the final write.
        auto shifted_exp = [&](auto x) { return migraphx::exp(x - slice_max); };
        auto slice_sum   = r.reduce(op::sum{}, 0, shifted_exp)(input);
        auto normalise   = [&](auto& y, auto x) { y = shifted_exp(x) / slice_sum; };
        r.inner(normalise)(output, input);
    });
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP #ifndef MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP #define MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_VEC_HPP #ifndef MIGRAPHX_GUARD_KERNELS_VEC_HPP
#define MIGRAPHX_GUARD_KERNELS_VEC_HPP #define MIGRAPHX_GUARD_KERNELS_VEC_HPP
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
...@@ -146,5 +171,19 @@ constexpr auto vec_packed_transform(Ts... xs) ...@@ -146,5 +171,19 @@ constexpr auto vec_packed_transform(Ts... xs)
}; };
} }
template <class T, class Op>
constexpr auto vec_reduce(T x, Op op)
{
if constexpr(vec_size<T>() < 2)
return x;
else
{
vec_type<T> result = x[0];
for(int i = 1; i < vec_size<T>(); i++)
result = op(result, x[i]);
return result;
}
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
...@@ -213,7 +236,9 @@ template <index_int N, index_int Axis, class T> ...@@ -213,7 +236,9 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x) __device__ __host__ auto vectorize_tensor(T x)
{ {
constexpr auto shape = get_shape_c<T>{}; constexpr auto shape = get_shape_c<T>{};
if constexpr(shape.strides[Axis] == 0) if constexpr(shape.lens[Axis] == 1)
return x;
else if constexpr(shape.strides[Axis] == 0)
return tensor_step<N>(x, _c<Axis>); return tensor_step<N>(x, _c<Axis>);
else else
return as_vec<N>(x, _c<Axis>); return as_vec<N>(x, _c<Axis>);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/leaky_relu.hpp> #include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/logsoftmax.hpp> #include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp> #include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/op/logsoftmax.hpp> #include <migraphx/op/logsoftmax.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/run_loop.hpp> #include <migraphx/run_loop.hpp>
#include <migraphx/gpu/loop.hpp> #include <migraphx/gpu/loop.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iterator> #include <iterator>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
...@@ -58,7 +81,6 @@ struct miopen_apply ...@@ -58,7 +81,6 @@ struct miopen_apply
const lowering* pass = nullptr; const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{}; instruction_ref last{};
std::unordered_map<instruction_ref, std::string> prog_output_names{};
bool offload_copy = false; bool offload_copy = false;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false; bool compute_fp32 = false;
...@@ -77,27 +99,6 @@ struct miopen_apply ...@@ -77,27 +99,6 @@ struct miopen_apply
(void)i; (void)i;
} }
// Records the module's final output instruction in `last` and, when the module
// ends in an "@return" instruction, fills `prog_output_names` with one generated
// name per returned value.
//
// Generated names have the form "<module name>:#output_<index>", where <index>
// is the zero-based position of the value among the inputs of "@return".
void create_output_names()
{
// `last` is the output alias of the final instruction in the module.
this->last = instruction::get_output_alias(std::prev(mod->end()));
if(this->last->name() == "@return")
{
const auto& prog_outputs = last->inputs();
// Resolve every returned instruction to its output alias, preserving order.
std::vector<instruction_ref> outputs_alias(prog_outputs.size());
std::transform(prog_outputs.begin(),
prog_outputs.end(),
outputs_alias.begin(),
[](const auto& i) { return instruction::get_output_alias(i); });
std::size_t index = 0;
for(auto ins : outputs_alias)
{
// Keyed by alias: if two returned values resolve to the same alias, the
// later assignment overwrites the earlier name; `index` still advances.
prog_output_names[ins] = mod->name() + ":#output_" + std::to_string(index++);
}
}
}
const std::unordered_set<std::string>& get_rocblas_fp32_archs() const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{ {
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"}; static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
...@@ -120,7 +121,6 @@ struct miopen_apply ...@@ -120,7 +121,6 @@ struct miopen_apply
#endif #endif
offload_copy = (mod->name() == "main") ? pass->offload_copy : false; offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
create_output_names();
add_generic_op("acos"); add_generic_op("acos");
add_generic_op("acosh"); add_generic_op("acosh");
...@@ -186,7 +186,6 @@ struct miopen_apply ...@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter_none"); add_extend_op("scatter_none");
add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
...@@ -201,7 +200,7 @@ struct miopen_apply ...@@ -201,7 +200,7 @@ struct miopen_apply
add_quant_convolution_op(); add_quant_convolution_op();
} }
void copy_params() void copy_params() const
{ {
if(not offload_copy) if(not offload_copy)
return; return;
...@@ -261,7 +260,7 @@ struct miopen_apply ...@@ -261,7 +260,7 @@ struct miopen_apply
copy_params(); copy_params();
} }
instruction_ref insert_precompile_op(instruction_ref ins) instruction_ref insert_precompile_op(instruction_ref ins) const
{ {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
...@@ -274,28 +273,9 @@ struct miopen_apply ...@@ -274,28 +273,9 @@ struct miopen_apply
ins->module_inputs()); ins->module_inputs());
} }
instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "") instruction_ref insert_allocation(instruction_ref ins, const shape& s) const
{ {
// Instruction's output is an input of the ret instruction return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}}));
if(offload_copy)
{
auto result = mod->insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(s)}, {"tag", std::move(tag)}}));
return result;
}
auto ins_alias = instruction::get_output_alias(ins);
if(last->name() == "@return" and tag.empty() and prog_output_names.count(ins_alias) > 0)
{
return mod->add_parameter(prog_output_names[ins_alias], s);
}
else if(ins == last and tag.empty())
{
return mod->add_parameter("output", s);
}
return mod->insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(s)}, {"tag", std::move(tag)}}));
} }
void add_convolution_op() void add_convolution_op()
...@@ -306,7 +286,7 @@ struct miopen_apply ...@@ -306,7 +286,7 @@ struct miopen_apply
auto conv = miopen_convolution{op, make_conv(op)}; auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs())); auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction( return mod->replace_instruction(
...@@ -320,9 +300,9 @@ struct miopen_apply ...@@ -320,9 +300,9 @@ struct miopen_apply
auto&& op = any_cast<op::deconvolution>(ins->get_operator()); auto&& op = any_cast<op::deconvolution>(ins->get_operator());
auto conv = miopen_deconvolution{op, make_deconv(op)}; auto conv = miopen_deconvolution{op, make_deconv(op)};
auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction( return mod->replace_instruction(
...@@ -335,27 +315,9 @@ struct miopen_apply ...@@ -335,27 +315,9 @@ struct miopen_apply
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
if(refs.size() == 2) assert(refs.size() == 2);
{ auto output = insert_allocation(ins, ins->get_shape());
auto output = insert_allocation(ins, ins->get_shape()); refs.push_back(output);
refs.push_back(output);
}
else
{
auto c_alias = instruction::get_output_alias(refs.back());
if(ins == last or refs.back()->outputs().size() > 1 or c_alias->inputs().empty())
{
auto output = insert_allocation(ins, ins->get_shape());
auto copy_out =
mod->insert_instruction(ins, make_op("hip::copy"), refs.back(), output);
refs.back() = copy_out;
refs.push_back(copy_out);
}
else
{
refs.push_back(refs.back());
}
}
return mod->replace_instruction( return mod->replace_instruction(
ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs); ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
}); });
...@@ -365,11 +327,25 @@ struct miopen_apply ...@@ -365,11 +327,25 @@ struct miopen_apply
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator()); auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)}; shape ws;
auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); miopen_quant_convolution conv;
auto compile_quant_conv_with_format = [&](bool format) {
conv = miopen_quant_convolution{op, format, make_conv(op)};
ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
};
try
{
compile_quant_conv_with_format(int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
compile_quant_conv_with_format(!int8_x4_format);
}
auto args = ins->inputs(); auto args = ins->inputs();
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output); return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output);
...@@ -466,33 +442,7 @@ struct miopen_apply ...@@ -466,33 +442,7 @@ struct miopen_apply
auto sync_cond = mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_cond); auto sync_cond = mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_cond);
inputs.front() = sync_cond; inputs.front() = sync_cond;
std::vector<module_ref> mod_args = ins->module_inputs(); return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
std::map<std::string, shape> name_shapes;
for(const auto& smod : mod_args)
{
auto ps = smod->get_parameter_shapes();
name_shapes.insert(ps.begin(), ps.end());
}
bool ins_output_allocated = false;
for(auto& pn : name_shapes)
{
const auto& s = pn.second;
instruction_ref output{};
if(s == ins->get_shape() and not ins_output_allocated)
{
output = insert_allocation(ins, s);
ins_output_allocated = true;
}
else
{
output = mod->insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(s)}}));
}
inputs.push_back(output);
}
return mod->replace_instruction(ins, ins->get_operator(), inputs, mod_args);
}); });
} }
...@@ -511,20 +461,17 @@ struct miopen_apply ...@@ -511,20 +461,17 @@ struct miopen_apply
inputs.at(0) = synced_max_iter; inputs.at(0) = synced_max_iter;
inputs.at(1) = cpu_cond; inputs.at(1) = cpu_cond;
auto copy_inputs = inputs; auto copy_inputs = inputs;
std::transform( std::transform(copy_inputs.begin(),
copy_inputs.begin(), copy_inputs.end(), std::back_inserter(inputs), [&](auto in) { copy_inputs.end(),
return mod->insert_instruction( std::back_inserter(inputs),
ins, make_op("hip::allocate", {{"shape", to_value(in->get_shape())}})); [&](auto in) { return insert_allocation(ins, in->get_shape()); });
});
auto mod_args = ins->module_inputs(); auto mod_args = ins->module_inputs();
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
const auto* sub_mod = mod_args.front(); const auto* sub_mod = mod_args.front();
auto cond_out = mod->insert_instruction( auto cond_out = insert_allocation(ins, sub_mod->get_output_shapes().front());
ins,
make_op("hip::allocate",
{{"shape", to_value(sub_mod->get_output_shapes().front())}}));
// add cond and mod outputs to the argument list // add cond and mod outputs to the argument list
inputs.push_back(cond_out); inputs.push_back(cond_out);
inputs.push_back(output); inputs.push_back(output);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment