Commit c00100f9 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround

parents cd4ab535 ad618aa9
...@@ -48,7 +48,9 @@ endif() ...@@ -48,7 +48,9 @@ endif()
set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib") set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib")
project(migraphx) project(migraphx LANGUAGES C CXX)
include(CTest)
find_package(ROCM REQUIRED) find_package(ROCM REQUIRED)
find_path(HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half) find_path(HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half)
...@@ -269,7 +271,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) ...@@ -269,7 +271,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
add_subdirectory(src) add_subdirectory(src)
add_subdirectory(docs) add_subdirectory(docs)
add_subdirectory(test) if(BUILD_TESTING)
add_subdirectory(test)
endif()
add_subdirectory(tools) add_subdirectory(tools)
set(DEST_DIR ${CMAKE_BINARY_DIR}) set(DEST_DIR ${CMAKE_BINARY_DIR})
......
...@@ -12,6 +12,9 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && ...@@ -12,6 +12,9 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.5/ focal main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.5/ focal main > /etc/apt/sources.list.d/rocm.list'
# Pin packages from repo.radeon.com above the distro archive (priority 600)
# so apt prefers the ROCm builds during dependency resolution, as required by
# the ROCm install instructions on docs.amd.com.
RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \ apt-utils \
...@@ -110,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -110,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@a997d5f51314b45d7a4c04f1599966dcf53f9b4d -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@8d25af3b3721c159bb41cc6388e9453b1018c126 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/source_location.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m)
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module for per section of matchers /// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms) void find_matches_for(source_location location, Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 const int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
const const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
#endif const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{});
int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool trace_for = not trace_filter.empty() and
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 (contains(std::string{location.file_name()}, trace_filter) or
const contains(std::string{location.function_name()}, trace_filter));
#endif bool match = false;
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
if(trace > 1 or trace_pass > 1) if(trace > 1 or trace_for)
std::cout << "Match: " << get_type_name(m) << std::endl; std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher()); auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end()) if(r.result == get_module(mod).end())
return; return;
if(trace > 0 or trace_pass > 0) if(trace > 0 or trace_for)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins); get_module(mod).debug_print(ins);
...@@ -420,23 +420,19 @@ void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms) ...@@ -420,23 +420,19 @@ void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms)
/// Find matches in a module /// Find matches in a module
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, Ms&&... ms) struct find_matches
{ {
for(auto ins : iterator_for(get_module(mod))) find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current())
{ {
find_matches(0, mod, ins, ms...); for(auto ins : iterator_for(get_module(mod)))
{
find_matches_for(location, mod, ins, ms...);
}
} }
} };
/// Find matches in a pass
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(size_t trace_pass, Mod& mod, Ms&&... ms) find_matches(Mod& mod, Ms&&... ms) -> find_matches<Mod, Ms...>;
{
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(trace_pass, mod, ins, ms...);
}
}
template <class M, class F> template <class M, class F>
struct find_generic_match struct find_generic_match
......
...@@ -66,17 +66,7 @@ struct convert : unary<convert> ...@@ -66,17 +66,7 @@ struct convert : unary<convert>
auto type = target_type; auto type = target_type;
return [type](auto x) { return [type](auto x) {
auto y = x; auto y = x;
shape::visit(type, [&](auto as) { shape::visit(type, [&](auto as) { y = as(x); });
// clamping value between target_type's max and min doesn't work for NaNs,
if(std::isnan(x))
{
y = as.nan();
}
else
{
y = std::min(std::max(as(x), as.min()), as.max());
}
});
return y; return y;
}; };
} }
......
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
// Needed for std::uint_least32_t used by the fallback struct below.
#include <cstdint>
#include <migraphx/config.hpp>
// Detect availability of C++20 std::source_location or the Library
// Fundamentals TS std::experimental::source_location.
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#elif defined(__has_include)
// Fix: C++20 compilers report __cplusplus as 202002L; the previous threshold
// of 202003L could never be satisfied by any C++20 implementation, so
// <source_location> was never selected.
#if __has_include(<source_location>) && __cplusplus >= 202002L
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#endif
#if __has_include(<experimental/source_location>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#else
// No __has_include support: assume neither header exists.
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#if MIGRAPHX_HAS_SOURCE_LOCATION
#include <source_location>
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
#include <experimental/source_location>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_SOURCE_LOCATION
using source_location = std::source_location;
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
using source_location = std::experimental::source_location;
#else
// Fallback stub exposing the std::source_location interface but carrying no
// information; used when neither the C++20 header nor the TS header exists.
struct source_location
{
    static constexpr source_location current() noexcept { return source_location{}; }
    constexpr std::uint_least32_t line() const noexcept { return 0; }
    constexpr std::uint_least32_t column() const noexcept { return 0; }
    constexpr const char* file_name() const noexcept { return ""; }
    constexpr const char* function_name() const noexcept { return ""; }
};
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
...@@ -31,8 +31,6 @@ namespace migraphx { ...@@ -31,8 +31,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct parse_where : op_parser<parse_where> struct parse_where : op_parser<parse_where>
{ {
std::vector<op_desc> operators() const { return {{"Where"}}; } std::vector<op_desc> operators() const { return {{"Where"}}; }
...@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where> ...@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where>
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(enabled(MIGRAPHX_ENABLE_CK{}))
{
// Convert condition tensor to int32 to work around CK not supporting bool type
args[0] = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
}
if(args[0]->get_shape().lens() != lens) if(args[0]->get_shape().lens() != lens)
{ {
args[0] = args[0] =
......
...@@ -359,6 +359,31 @@ std::string classify(T x) ...@@ -359,6 +359,31 @@ std::string classify(T x)
} }
} }
// Print min, max, mean and (population) standard deviation of an argument's
// elements to `os`. Tuple arguments are handled by the second visitor, which
// recurses into each sub-argument.
void print_statistics(std::ostream& os, const argument& a)
{
    a.visit(
        [&](auto t) {
            // Guard the empty case: dereferencing min_element/max_element of
            // an empty range (== end()) is undefined behavior, and the
            // mean/stddev below would divide by zero.
            if(t.size() == 0)
            {
                os << "Empty argument\n";
                return;
            }
            os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
            os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
            double num_elements = t.size();
            auto mean           = std::reduce(t.begin(), t.end(), 0.0) / num_elements;
            auto stddev         = std::sqrt(
                std::accumulate(t.begin(),
                                t.end(),
                                0.0,
                                [&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
                num_elements);
            os << "Mean: " << mean << ", ";
            os << "StdDev: " << stddev << "\n";
        },
        [&](const auto& xs) {
            // Tuple argument: print statistics for each element.
            for(const auto& x : xs)
            {
                print_statistics(os, x);
            }
        });
}
std::unordered_set<std::string> classify_argument(const argument& a) std::unordered_set<std::string> classify_argument(const argument& a)
{ {
std::unordered_set<std::string> result; std::unordered_set<std::string> result;
...@@ -578,6 +603,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -578,6 +603,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
std::cout << "Output: "; std::cout << "Output: ";
preview_argument(std::cout, buffer); preview_argument(std::cout, buffer);
std::cout << std::endl; std::cout << std::endl;
print_statistics(std::cout, buffer);
} }
else else
{ {
......
...@@ -39,8 +39,6 @@ ...@@ -39,8 +39,6 @@
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <unordered_set> #include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -1487,13 +1485,10 @@ struct find_split_transpose ...@@ -1487,13 +1485,10 @@ struct find_split_transpose
void simplify_algebra::apply(module& m) const void simplify_algebra::apply(module& m) const
{ {
size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(trace, match::find_matches(m,
m,
find_inner_broadcast{}, find_inner_broadcast{},
find_dot_broadcast{}, find_dot_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
......
...@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES}) ...@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
endif() endif()
endforeach() endforeach()
target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_cpu TARGETS migraphx_cpu
INCLUDE INCLUDE
......
...@@ -139,7 +139,8 @@ struct find_mlir_op ...@@ -139,7 +139,8 @@ struct find_mlir_op
auto matcher() const auto matcher() const
{ {
auto dot_or_conv = match::skip(match::name("contiguous"))( auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op")); match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
} }
...@@ -190,6 +191,68 @@ struct find_mlir_op ...@@ -190,6 +191,68 @@ struct find_mlir_op
return {new_gemm_based_op, top_inputs}; return {new_gemm_based_op, top_inputs};
} }
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise
// function) on particular types. Returns true when instruction `i` may be
// placed inside a fused MLIR module.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
    using type_t           = shape::type_t;
    const auto& name       = i.name();
    const auto result_type = i.get_shape().type();
    const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                         type_t::half_type,
                                                         type_t::int8_type,
                                                         type_t::int32_type,
                                                         type_t::bool_type};
    // Preliminary type check.
    if(not contains(allowed_types, result_type))
    {
        return false;
    }
    const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
    const std::initializer_list<std::string> no_bool_ops  = {"convolution",
                                                             "quant_convolution",
                                                             "dot",
                                                             "quant_dot",
                                                             "add",
                                                             "clip",
                                                             "sub",
                                                             "mul",
                                                             "div",
                                                             "pow",
                                                             "where",
                                                             "quantizelinear",
                                                             "dequantizelinear",
                                                             "abs",
                                                             "neg"};
    // Fix: a missing comma between "sigmoid" and "softmax" previously
    // concatenated the two adjacent string literals into the single
    // (nonexistent) operator name "sigmoidsoftmax", silently rejecting both.
    const std::initializer_list<std::string> fp_only_ops = {"ceil",
                                                            "erf",
                                                            "exp",
                                                            "floor",
                                                            "log",
                                                            "recip",
                                                            "rsqrt",
                                                            "sigmoid",
                                                            "softmax",
                                                            "tanh"};
    const bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
    if(contains(any_type_ops, name))
        return true;
    if(result_type != type_t::bool_type and contains(no_bool_ops, name))
        return true;
    if(is_float and contains(fp_only_ops, name))
        return true;
    // Only conversions between floating types are known to be unambiguously
    // supported.
    if(is_float and name == "convert")
    {
        return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
            return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
        });
    }
    return false;
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -197,31 +260,12 @@ struct find_mlir_op ...@@ -197,31 +260,12 @@ struct find_mlir_op
auto x_ins = r.instructions["x"]; // input after contiguous auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
// Whitelist pointwise operators // Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [](const auto& i) { if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not contains({"@literal", return not is_pointwise_op_supported_by_mlir(i);
"@param",
"@return",
"convolution",
"quant_convolution",
"dot",
"add",
"relu",
"dequantizelinear",
"quantizelinear",
"mul"},
i.name());
}))
return;
// Only fuse with fp32/fp16/int8/int32
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::int32_type},
i->get_shape().type());
})) }))
return; return;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name()); module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass(); mm->set_bypass();
......
...@@ -121,7 +121,10 @@ struct mlir_handle ...@@ -121,7 +121,10 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT #define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy); using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_thread_pool = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirLlvmThreadPool, mlirLlvmThreadPoolDestroy);
using mlir_dialect_registry = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirDialectRegistry,
mlirDialectRegistryDestroy);
using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy); using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy); using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags, using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
...@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch) ...@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch)
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
: ctx(mlirContextCreate()), : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location))
{ {
MlirDialectRegistry registry = mlirDialectRegistryCreate(); mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirRegisterRocMLIRDialects(registry);
mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextLoadAllAvailableDialects(ctx.get()); mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry); }
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
// Return the process-wide MLIR dialect registry, creating and populating it
// exactly once. The MLIR registration functions (for dialects and passes) are
// not necessarily thread-safe and need to be executed exactly once
// (especially since they eventually call non-thread-safe LLVM
// initializations), hence the call_once guard.
static mlir_dialect_registry& get_dialect_registry()
{
    static std::once_flag init_guard;
    static mlir_dialect_registry the_registry;
    std::call_once(init_guard, [&]() {
        the_registry = mlirDialectRegistryCreate();
        mlirRegisterRocMLIRDialects(the_registry.get());
        mlirRegisterRocMLIRPasses();
    });
    return the_registry;
}
static mlir_thread_pool& get_thread_pool()
{
// To save on overhead, we create one LLVM thread pool and reuse it
// across all MLIR contexts as recommended by MLIR upstream.
// Note that this is thread-safe as of C++11.
static mlir_thread_pool the_pool = mlirLlvmThreadPoolCreate();
return the_pool;
} }
MlirType make_type(shape::type_t t) const MlirType make_type(shape::type_t t) const
...@@ -244,8 +269,6 @@ struct mlir_program ...@@ -244,8 +269,6 @@ struct mlir_program
MlirAttribute attribute(std::int64_t i) const MlirAttribute attribute(std::int64_t i) const
{ {
if(i < 0)
MIGRAPHX_THROW("MLIR cant handle negative values since they are ambiguous");
return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i); return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
} }
MlirAttribute attribute(std::uint64_t i) const MlirAttribute attribute(std::uint64_t i) const
......
...@@ -77,7 +77,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) ...@@ -77,7 +77,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
struct id_pass struct id_pass
{ {
...@@ -126,7 +125,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -126,7 +125,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module{}, inline_module{},
rewrite_pooling{}, rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_FAST_GELU{}), rewrite_gelu{}), enable_pass(options.fast_math, rewrite_gelu{}),
optimize_module{}, optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}), enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
cmake_policy(SET CMP0057 NEW) cmake_policy(SET CMP0057 NEW)
include(CTest)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
include(ProcessorCount) include(ProcessorCount)
ProcessorCount(N) ProcessorCount(N)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
// Apply the GPU MLIR fusion pass followed by dead-code elimination to `p`.
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
}
// Create a bypass submodule called `name` inside `p`, populate it via `f`,
// and append a "gpu::mlir_op" instruction to the main module that runs it on
// `inputs`. `f` receives the submodule pointer and its parameter
// instructions and must return a tuple of (root gemm-based instruction,
// instruction whose value is returned from the submodule).
template <class F>
migraphx::instruction_ref add_mlir(migraphx::program& p,
                                   const std::string& name,
                                   std::vector<migraphx::instruction_ref> inputs,
                                   std::vector<std::string> arg_names,
                                   F f)
{
    assert(inputs.size() == arg_names.size() && "One interior parameter name given per input.");
    auto* main_mod = p.get_main_module();
    auto* sub_mod  = p.create_module(name);
    sub_mod->set_bypass();
    // One submodule parameter per input, named by the caller.
    std::vector<migraphx::instruction_ref> params;
    params.reserve(inputs.size());
    for(size_t idx = 0; idx < inputs.size(); ++idx)
        params.push_back(sub_mod->add_parameter(arg_names[idx], inputs[idx]->get_shape()));
    auto [root, ret] = f(sub_mod, params);
    sub_mod->add_return({ret});
    return main_mod->add_instruction(
        migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root->get_operator())}}),
        inputs,
        {sub_mod});
}
// A dot followed by a pointwise add must be fused into a single gpu::mlir_op
// whose submodule contains both the dot and the add.
TEST_CASE(dot_add)
{
    migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
    // Program as written: dot, then a separate pointwise module for the add.
    migraphx::program p1;
    {
        auto* mm = p1.get_main_module();
        auto a   = mm->add_parameter("a", s);
        auto b   = mm->add_parameter("b", s);
        auto x   = mm->add_parameter("x", s);
        auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
        auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
        mm->add_return({add});
    }
    run_pass(p1);
    // Expected program after the pass: one fused mlir_op. Note the pointwise
    // operand x is listed first in the fused input list.
    migraphx::program p2;
    {
        auto* mm = p2.get_main_module();
        auto a   = mm->add_parameter("a", s);
        auto b   = mm->add_parameter("b", s);
        auto x   = mm->add_parameter("x", s);
        auto fused =
            add_mlir(p2,
                     "mlir_main:pointwise0",
                     {x, a, b},
                     {"x1", "y0", "y1"},
                     [=](auto* pm, const auto& inputs) {
                         auto dot =
                             pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
                         auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]);
                         return std::make_tuple(dot, add);
                     });
        mm->add_return({fused});
    }
    EXPECT(p1.sort() == p2.sort());
}
// An int8 quant_dot followed by a pointwise abs should be fused into a
// single gpu::mlir_op containing both operations.
TEST_CASE(int_quant_dot_abs)
{
    migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
    migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
    // Program as written: quant_dot, then a separate pointwise abs module.
    migraphx::program p1;
    {
        auto* mm = p1.get_main_module();
        auto a   = mm->add_parameter("a", s_a);
        auto b   = mm->add_parameter("b", s_b);
        auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
        auto abs = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("abs"));
        mm->add_return({abs});
    }
    run_pass(p1);
    // Expected program after the pass: one fused mlir_op submodule.
    migraphx::program p2;
    {
        auto* mm = p2.get_main_module();
        auto a   = mm->add_parameter("a", s_a);
        auto b   = mm->add_parameter("b", s_b);
        auto fused = add_mlir(
            p2, "mlir_main:pointwise0", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
                auto dot =
                    pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]);
                auto abs = pm->add_instruction(migraphx::make_op("abs"), dot);
                return std::make_tuple(dot, abs);
            });
        mm->add_return({fused});
    }
    EXPECT(p1.sort() == p2.sort());
}
// Negative case: quant_dot produces an integer result, and tanh is not
// fusable for integer types, so the pass must leave the program untouched.
TEST_CASE(int_quant_dot_tanh_fails)
{
    migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
    migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
    migraphx::program p1;
    {
        auto* mm  = p1.get_main_module();
        auto a    = mm->add_parameter("a", s_a);
        auto b    = mm->add_parameter("b", s_b);
        auto dot  = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
        auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
        mm->add_return({tanh});
    }
    // Snapshot before the pass for comparison.
    migraphx::program p2(p1);
    // This pass should do nothing as int32_t tanh isn't supported.
    run_pass(p1);
    EXPECT(p1 == p2);
}
int main(int argc, const char* argv[])
{
    // The whole suite only applies when MLIR support is compiled in;
    // otherwise exit successfully without running any tests.
    if(not migraphx::gpu::mlir_enabled())
        return 0;
    test::run(argc, argv);
    return 0;
}
...@@ -187,12 +187,39 @@ module { ...@@ -187,12 +187,39 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
// Dump a quant_dot + add module to MLIR text, compare against the expected
// printout, and verify the compiled MLIR kernel against the reference.
TEST_CASE(quant_dot_add)
{
    const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__";
    migraphx::module m;
    auto arg0 = m.add_parameter("arg0", {migraphx::shape::int8_type, {1, 5, 4}});
    auto arg1 = m.add_parameter("arg1", {migraphx::shape::int8_type, {1, 4, 3}});
    auto arg2 = m.add_parameter("arg2", {migraphx::shape::int32_type, {1, 5, 3}});
    auto conv = m.add_instruction(migraphx::make_op("quant_dot"), arg0, arg1);
    auto add  = m.add_instruction(migraphx::make_op("add"), conv, arg2);
    m.add_return({add});
    auto s = migraphx::gpu::dump_mlir(m);
    // Skip test if MLIR is not enabled
    if(s.empty())
        return;
    CHECK(encode(s) == encode(mlir_output));
    EXPECT(verify_mlir(m));
}
TEST_CASE(dot_add) TEST_CASE(dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} { func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32> %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32> return %1 : tensor<1x5x3xf32>
} }
...@@ -246,4 +273,57 @@ module { ...@@ -246,4 +273,57 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
// Dump a dot + convert(fp32 -> fp16) module to MLIR text, compare against
// the expected printout, and verify the compiled kernel.
TEST_CASE(dot_convert)
{
    const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
}
}
)__migraphx__";
    migraphx::module m;
    auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
    auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
    auto dot  = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
    // Truncating conversion of the fp32 dot result down to half precision.
    auto trunc = m.add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), dot);
    m.add_return({trunc});
    auto s = migraphx::gpu::dump_mlir(m);
    // Skip test if MLIR is not enabled
    if(s.empty())
        return;
    CHECK(encode(s) == encode(mlir_output));
    EXPECT(verify_mlir(m));
}
// Dump a dot + where module (bool condition selecting between the dot result
// and a third input) to MLIR text, compare against the expected printout,
// and verify the compiled kernel. Note the bool condition prints as i8.
TEST_CASE(dot_where)
{
    const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
    migraphx::module m;
    auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
    auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
    auto arg2 = m.add_parameter("arg2", {migraphx::shape::bool_type, {1, 5, 3}});
    auto arg3 = m.add_parameter("arg3", {migraphx::shape::float_type, {1, 5, 3}});
    auto dot   = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
    auto where = m.add_instruction(migraphx::make_op("where"), arg2, dot, arg3);
    m.add_return({where});
    auto s = migraphx::gpu::dump_mlir(m);
    // Skip test if MLIR is not enabled
    if(s.empty())
        return;
    CHECK(encode(s) == encode(mlir_output));
    EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
5a43828b3d73028bfd33b3856f82698d9ab02cb1 fbf08c4b4dce5da245189203d9f6cfc41f6663a2
...@@ -38,7 +38,7 @@ struct test_pad_highest : verify_program<test_pad_highest> ...@@ -38,7 +38,7 @@ struct test_pad_highest : verify_program<test_pad_highest>
migraphx::shape s0{migraphx::shape::half_type, {2, 2}}; migraphx::shape s0{migraphx::shape::half_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s0, data0}); auto l0 = mm->add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{}; migraphx::op::pad op{};
op.value = std::numeric_limits<float>::max(); op.value = std::numeric_limits<migraphx::half>::max();
op.pads = {0, 0, 1, 1}; op.pads = {0, 0, 1, 1};
mm->add_instruction(op, l0); mm->add_instruction(op, l0);
return p; return p;
......
...@@ -38,7 +38,7 @@ struct test_pad_lowest : verify_program<test_pad_lowest> ...@@ -38,7 +38,7 @@ struct test_pad_lowest : verify_program<test_pad_lowest>
migraphx::shape s0{migraphx::shape::half_type, {2, 2}}; migraphx::shape s0{migraphx::shape::half_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s0, data0}); auto l0 = mm->add_literal(migraphx::literal{s0, data0});
migraphx::op::pad op{}; migraphx::op::pad op{};
op.value = std::numeric_limits<float>::lowest(); op.value = std::numeric_limits<migraphx::half>::lowest();
op.pads = {0, 0, 1, 1}; op.pads = {0, 0, 1, 1};
mm->add_instruction(op, l0); mm->add_instruction(op, l0);
return p; return p;
......
# CI/development image for AMDMIGraphX: Ubuntu 22.04 (jammy) + ROCm 5.5
# + build tooling, Python deps, and an onnxruntime checkout for e2e tests.
FROM ubuntu:22.04

# Prefix under which locally built dependencies (zstd, ccache, rocMLIR, ...) are installed.
ARG PREFIX=/usr/local

# Support multiarch (allow installing i386 packages alongside amd64)
RUN dpkg --add-architecture i386

# Install rocm key so the ROCm apt repository below is trusted
RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \
curl -fsSL http://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg

# Add rocm repository (ROCm 5.5 for jammy, signed with the key installed above)
RUN sh -c "echo 'deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] http://repo.radeon.com/rocm/apt/5.5 jammy main' > /etc/apt/sources.list.d/rocm.list"

# From docs.amd.com for installing rocm: pin repo.radeon.com packages above
# the Ubuntu archive so ROCm's versions win. Needed to install properly.
RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"

# Install dependencies: general build tools plus the ROCm userspace stack
# (MIOpen, rocBLAS, RCCL, roctracer, ...) that MIGraphX builds against.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \
build-essential \
#clang-format-10 \
cmake \
curl \
doxygen \
#g++-7 \
gdb \
git \
lcov \
locales \
pkg-config \
python3 \
python3-dev \
python3-pip \
software-properties-common \
wget \
rocm-device-libs \
hip-base \
libnuma-dev \
miopen-hip \
rocblas \
hipfft \
rocthrust \
rocrand \
hipsparse \
rccl \
rccl-dev \
rocm-smi-lib \
rocm-dev \
roctracer-dev \
hipcub \
hipblas \
hipify-clang \
half \
libssl-dev \
zlib1g-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Add this for roctracer dependencies
RUN pip3 install CppHeaderParser

# Workaround broken rocm packages: the packages install under a versioned
# /opt/rocm-* directory; expose it at /opt/rocm and register the library
# paths with the dynamic linker.
RUN ln -s /opt/rocm-* /opt/rocm
RUN echo "/opt/rocm/lib" > /etc/ld.so.conf.d/rocm.conf
RUN echo "/opt/rocm/llvm/lib" > /etc/ld.so.conf.d/rocm-llvm.conf
RUN ldconfig

# Generate and select a UTF-8 locale
RUN locale-gen en_US.UTF-8
RUN update-locale LANG=en_US.UTF-8
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8

# Install dependencies declared by the project itself (via install_prereqs.sh);
# the script is removed afterwards to keep the layer clean, and the /usr/local/hash
# marker it writes is asserted to confirm it completed.
ADD dev-requirements.txt /dev-requirements.txt
ADD requirements.txt /requirements.txt
ADD rbuild.ini /rbuild.ini
COPY ./tools/install_prereqs.sh /
RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh
RUN test -f /usr/local/hash || exit 1

# Install yapf (Python formatter used by the project's format checks)
RUN pip3 install yapf==0.28.0

# Install doc requirements
ADD doc/requirements.txt /doc-requirements.txt
RUN pip3 install -r /doc-requirements.txt

# Download real models to run onnx unit tests
ENV ONNX_HOME=/.onnx
COPY ./tools/download_models.sh /
RUN /download_models.sh && rm /download_models.sh

# Install latest ccache version (zstd is built first as a ccache dependency;
# a newer CMake is installed under /opt/cmake for onnxruntime)
RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
RUN cget -p /opt/cmake install kitware/cmake@v3.24.3

# Clone onnxruntime at the commit pinned in test/onnx/.onnxrt-commit
# (overridable via ONNXRUNTIME_COMMIT) and install its common dependencies.
COPY ./test/onnx/.onnxrt-commit /
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=main
ARG ONNXRUNTIME_COMMIT
RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \
cd onnxruntime && \
if [ -z "$ONNXRUNTIME_COMMIT" ] ; then git checkout $(cat /.onnxrt-commit) ; else git checkout ${ONNXRUNTIME_COMMIT} ; fi && \
/bin/sh /onnxruntime/dockerfiles/scripts/install_common_deps.sh
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh

# Build and install the pinned rocMLIR revision used by the MLIR backend tests
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@a997d5f51314b45d7a4c04f1599966dcf53f9b4d -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off

# Keep MIOpen's find/user databases in /tmp so they are writable at runtime
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV LD_LIBRARY_PATH=$PREFIX/lib

# Setup ubsan environment to print stacktraces, and make asan stricter
ENV UBSAN_OPTIONS=print_stacktrace=1
ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1

# Make the ROCm llvm-symbolizer available on PATH so sanitizer reports are symbolized
RUN ln -s /opt/rocm/llvm/bin/llvm-symbolizer /usr/bin/llvm-symbolizer
...@@ -56,8 +56,9 @@ echo "Dependencies are installed at $PREFIX" ...@@ -56,8 +56,9 @@ echo "Dependencies are installed at $PREFIX"
# Install deps with rbuild # Install deps with rbuild
rbuild prepare -d $PREFIX -s develop rbuild prepare -d $PREFIX -s develop
# install onnx package for unit tests export CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
pip3 install onnx==1.10.2 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==23.0 pip3 install onnx==1.10.2 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==23.0
# pin version of protobuf in Python for onnx runtime unit tests # pin version of protobuf in Python for onnx runtime unit tests between dist versions
pip3 install protobuf==3.20.0 pip3 install protobuf==3.20.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment