"llm/vscode:/vscode.git/clone" did not exist on "9220b4fa916e3a5434083caa3d9920fb9e2149ba"
Unverified Commit 6aa6c954 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Use parallel STL for parallel execution (#2165)

parent d3e5a5c0
...@@ -22,6 +22,8 @@ def rocmtestnode(Map conf) { ...@@ -22,6 +22,8 @@ def rocmtestnode(Map conf) {
def cmd = """ def cmd = """
ulimit -c unlimited ulimit -c unlimited
echo "leak:dnnl::impl::malloc" > suppressions.txt echo "leak:dnnl::impl::malloc" > suppressions.txt
echo "leak:libtbb.so" >> suppressions.txt
cat suppressions.txt
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt" export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
export MIGRAPHX_GPU_DEBUG=${gpu_debug} export MIGRAPHX_GPU_DEBUG=${gpu_debug}
export CXX=${compiler} export CXX=${compiler}
......
...@@ -28,6 +28,7 @@ include(ROCMInstallTargets) ...@@ -28,6 +28,7 @@ include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers) include(ROCMPackageConfigHelpers)
include(RegisterOp) include(RegisterOp)
include(CheckCXXLinkerFlag) include(CheckCXXLinkerFlag)
include(CheckCXXSourceCompiles)
add_library(migraphx add_library(migraphx
adjust_allocation.cpp adjust_allocation.cpp
...@@ -263,6 +264,50 @@ endif() ...@@ -263,6 +264,50 @@ endif()
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_link_libraries(migraphx PUBLIC Threads::Threads) target_link_libraries(migraphx PUBLIC Threads::Threads)
function(check_execution_par RESULT)
set(CMAKE_REQUIRED_LIBRARIES ${ARGN})
set(CMAKE_REQUIRED_FLAGS)
if(NOT MSVC)
set(CMAKE_REQUIRED_FLAGS "-std=c++17")
endif()
string(MD5 _flags_hash "${CMAKE_REQUIRED_FLAGS} ${CMAKE_REQUIRED_LIBRARIES}")
set(_source "
#include <execution>
int main() {
int* i = nullptr;
std::sort(std::execution::par, i, i);
}
")
check_cxx_source_compiles("${_source}" _has_execution_${_flags_hash})
set(${RESULT} ${_has_execution_${_flags_hash}} PARENT_SCOPE)
endfunction()
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT Off)
find_package(TBB)
if(TBB_FOUND)
check_execution_par(TBB_HAS_EXECUTION_PAR TBB::tbb)
if(TBB_HAS_EXECUTION_PAR)
target_link_libraries(migraphx PUBLIC TBB::tbb)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
message(STATUS "Using TBB for parallel execution")
endif()
else()
check_execution_par(HAS_EXECUTION_PAR)
if(HAS_EXECUTION_PAR)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
endif()
endif()
option(MIGRAPHX_HAS_EXECUTORS "C++ supports parallel executors" ${MIGRAPHX_HAS_EXECUTORS_DEFAULT})
if(MIGRAPHX_HAS_EXECUTORS)
message("Parallel STL enabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=1)
else()
message("Parallel STL disabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=0)
endif()
find_package(nlohmann_json 3.8.0 REQUIRED) find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx) migraphx_generate_export_header(migraphx)
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -95,11 +96,11 @@ struct binary : op_name<Derived> ...@@ -95,11 +96,11 @@ struct binary : op_name<Derived>
{ {
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(), par_transform(input1.begin(),
input1.end(), input1.end(),
input2.begin(), input2.begin(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
}); });
return result; return result;
} }
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -84,10 +85,10 @@ struct unary : op_name<Derived> ...@@ -84,10 +85,10 @@ struct unary : op_name<Derived>
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
std::transform(input.begin(), par_transform(input.begin(),
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
}); });
}); });
return result; return result;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
struct exception_list
{
std::vector<std::exception_ptr> exceptions;
std::mutex m;
void add_exception()
{
std::lock_guard<std::mutex> guard(m);
exceptions.push_back(std::current_exception());
}
template <class F>
auto collect(F f)
{
return [f, this](auto&&... xs) {
try
{
f(std::forward<decltype(xs)>(xs)...);
}
catch(...)
{
this->add_exception();
}
};
}
void throw_if_exception() const
{
if(not exceptions.empty())
std::rethrow_exception(exceptions.front());
}
};
} // namespace detail
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt par_transform(InputIt first1, InputIt last1, OutputIt d_first, UnaryOperation unary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(std::execution::par, first1, last1, d_first, std::move(unary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = unary_op(first1[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt par_transform(
InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, BinaryOperation binary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(
std::execution::par, first1, last1, first2, d_first, std::move(binary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = binary_op(first1[i], first2[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt, class UnaryFunction>
void par_for_each(InputIt first, InputIt last, UnaryFunction f)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail::exception_list ex;
std::for_each(std::execution::par, first, last, ex.collect(std::move(f)));
ex.throw_if_exception();
#else
simple_par_for(last - first, [&](auto i) { f(first[i]); });
#endif
}
template <class... Ts>
auto par_copy_if(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::copy_if(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::copy_if(std::forward<Ts>(xs)...);
#endif
}
template <class... Ts>
auto par_sort(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::sort(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::sort(std::forward<Ts>(xs)...);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
...@@ -24,93 +24,23 @@ ...@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP #define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread> #include <migraphx/par.hpp>
#include <cmath> #include <migraphx/ranges.hpp>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, F f)
{ {
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(), using iterator = basic_iota_iterator<id, std::size_t>;
n / std::max<std::size_t>(1, min_grain)); par_for_each(iterator{0, {}}, iterator{n, {}}, f);
par_for_impl(n, threadsize, f);
} }
template <class F> template <class F>
void par_for(std::size_t n, F f) void par_for(std::size_t n, std::size_t, F f)
{ {
const int min_grain = 8; par_for(n, f);
par_for(n, min_grain, f);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void simple_par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void simple_par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
simple_par_for_impl(n, threadsize, f);
}
template <class F>
void simple_par_for(std::size_t n, F f)
{
const int min_grain = 8;
simple_par_for(n, min_grain, f);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp> #include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp> #include <migraphx/dom_info.hpp>
...@@ -461,7 +461,7 @@ struct stream_info ...@@ -461,7 +461,7 @@ struct stream_info
std::back_inserter(index_to_ins), std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; }); [](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) { simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index]; auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0); assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first); auto& merge_second = concur_ins.at(merge_first);
......
...@@ -53,7 +53,8 @@ else ...@@ -53,7 +53,8 @@ else
python3-pip \ python3-pip \
python3-venv \ python3-venv \
rocblas-dev \ rocblas-dev \
rocm-cmake rocm-cmake \
libtbb-dev
fi fi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment