Commit f1c8e6c9 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-integration-tuning

parents d09b7682 c1b8c975
......@@ -44,7 +44,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
namespace migraphx {
extern "C" {
__global__ void pad_kernel(void* input_p, void* output_p)
MIGRAPHX_GLOBAL void pad_kernel(void* input_p, void* output_p)
{
auto offsets = index_ints<${offsets}>{};
auto idx = make_index();
......
......@@ -44,7 +44,7 @@ namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
auto idx = make_index();
pointwise(idx, ${transformers})(${lambda}, ${args});
......
......@@ -45,7 +45,7 @@ namespace migraphx {
${preamble}
extern "C" {
__global__ void reduce_kernel(void* input_p, void* output_p)
MIGRAPHX_GLOBAL void reduce_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
......
......@@ -41,7 +41,7 @@ namespace migraphx {
extern "C" {
__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
MIGRAPHX_GLOBAL void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
......
......@@ -42,7 +42,7 @@ namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{});
......
......@@ -45,7 +45,7 @@ static const char* const softmax_kernel = R"__migraphx__(
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
MIGRAPHX_GLOBAL void softmax_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
......
......@@ -52,7 +52,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
ck::make_tuple(to_ck_tensor<Ds>()...),
to_ck_tensor<E>());
static_assert(desc.is_valid, "Invalid ck gemm.");
static_assert(desc.IsValid(), "Invalid ck gemm.");
G::Run(desc,
to_ck_const_pointer(a.data()),
......
......@@ -22,12 +22,19 @@
* THE SOFTWARE.
*/
#include <iterator>
#include <migraphx/gpu/lowering.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
......@@ -35,17 +42,12 @@
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -53,8 +55,9 @@ namespace gpu {
struct miopen_apply
{
module* mod = nullptr;
const lowering* pass = nullptr;
module* mod = nullptr;
module_pass_manager* mpm = nullptr;
const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{};
bool offload_copy = false;
......@@ -83,8 +86,7 @@ struct miopen_apply
auto& ctx = get_context();
int8_x4_format = get_int8_x4_format(ctx);
compute_fp32 = get_compute_fp32_flag();
// TODO: Set Offload copy based on root modules' compile options
offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous");
......@@ -376,7 +378,10 @@ struct miopen_apply
}
};
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
void lowering::apply(module_pass_manager& mpm) const
{
miopen_apply{&mpm.get_module(), &mpm, this}.apply();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -121,7 +121,10 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_thread_pool = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirLlvmThreadPool, mlirLlvmThreadPoolDestroy);
using mlir_dialect_registry = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirDialectRegistry,
mlirDialectRegistryDestroy);
using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
......@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch)
struct mlir_program
{
mlir_program()
: ctx(mlirContextCreate()),
: ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry);
mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
}
static mlir_dialect_registry& get_dialect_registry()
{
static std::once_flag init_guard;
static mlir_dialect_registry the_registry;
// The MLIR registration functions (for dialects and passes) are not
// necessarily thread-safe and need to be executed exactly once
// (especially since they eventually call non-thread-safe LLVM
// initilizations).
std::call_once(init_guard, [&]() {
the_registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(the_registry.get());
mlirRegisterRocMLIRPasses();
});
return the_registry;
}
static mlir_thread_pool& get_thread_pool()
{
// To save on overhead, we create one LLVM thread pool and reuse it
// across all MLIR contexts as recommended by MLIR upstream.
// Note that this is thread-safe as of C++11.
static mlir_thread_pool the_pool = mlirLlvmThreadPoolCreate();
return the_pool;
}
MlirType make_type(shape::type_t t) const
......@@ -244,8 +269,6 @@ struct mlir_program
MlirAttribute attribute(std::int64_t i) const
{
if(i < 0)
MIGRAPHX_THROW("MLIR cant handle negative values since they are ambiguous");
return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
}
MlirAttribute attribute(std::uint64_t i) const
......
......@@ -76,7 +76,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
struct id_pass
{
......@@ -125,7 +124,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module{},
rewrite_pooling{},
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_FAST_GELU{}), rewrite_gelu{}),
enable_pass(options.fast_math, rewrite_gelu{}),
optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{},
......
......@@ -24,8 +24,6 @@
cmake_policy(SET CMP0057 NEW)
include(CTest)
find_package(Threads REQUIRED)
include(ProcessorCount)
ProcessorCount(N)
......
......@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y)
abort();
}
int main()
int main(void)
{
char name[1024];
migraphx_operation_t op;
......
......@@ -41,7 +41,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
......@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
......@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(nop{});
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
......@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
......@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("sub"), one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
......@@ -121,11 +121,11 @@ TEST_CASE(depth_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto x1 = mm->add_instruction(sum_op{}, one, two);
auto x2 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(sum_op{}, one, two);
auto x1 = mm->add_instruction(migraphx::make_op("add"), one, two);
auto x2 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
......@@ -141,7 +141,7 @@ TEST_CASE(undefined_test)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
......@@ -232,7 +232,6 @@ TEST_CASE(reused_twice)
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
......@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated)
EXPECT(p == create_program());
}
TEST_CASE(tuple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(tuple_op{}, one, two);
mm->add_return({one, two});
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -45,7 +45,7 @@ TEST_CASE(simple_test)
auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one);
auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity);
mm->add_instruction(migraphx::make_op("add"), one_identity, two_identity);
run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
......@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two);
auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency)
auto one = mm->add_literal(1.0);
auto two = mm->add_literal(2.0);
auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, ans, three);
auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p);
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......
......@@ -27,6 +27,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/make_op.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -49,7 +50,7 @@ struct id_target
struct id_ctx_op
{
std::string name() const { return "id_ctx_op"; }
std::string name() const { return ""; }
migraphx::argument
compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{
......@@ -156,7 +157,7 @@ TEST_CASE(literal_test1)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -168,8 +169,8 @@ TEST_CASE(literal_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5});
......@@ -182,7 +183,7 @@ TEST_CASE(print_test)
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, x, two);
mm->add_instruction(migraphx::make_op("add"), x, two);
std::stringstream ss;
ss << p;
......@@ -197,7 +198,7 @@ TEST_CASE(param_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}})
.back();
......@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({
......@@ -245,7 +246,7 @@ TEST_CASE(get_param1)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
......@@ -257,7 +258,7 @@ TEST_CASE(get_param2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
}
......@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s);
......@@ -281,8 +282,8 @@ TEST_CASE(replace_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->replace_instruction(sum, minus_op{}, two, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->replace_instruction(sum, migraphx::make_op("sub"), two, one);
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
......@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == mm->end()});
......@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == mm->end()});
......@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one);
sum->replace(minus_op{});
auto sum = mm->add_instruction(migraphx::make_op("add"), two, one);
sum->replace(migraphx::make_op("sub"));
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
......@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
}
......@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two);
auto sum0 = mm->insert_instruction(sum1, migraphx::make_op("add"), two, two);
mm->replace_instruction(sum1, migraphx::make_op("sub"), sum0, two);
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
......@@ -372,8 +373,8 @@ TEST_CASE(remove_test1)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto removed = mm->add_instruction(migraphx::make_op("sub"), sum, one);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()});
......@@ -388,8 +389,8 @@ TEST_CASE(remove_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto removed = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()});
......@@ -404,7 +405,7 @@ TEST_CASE(target_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(id_target{});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......@@ -460,7 +461,7 @@ TEST_CASE(eval_context1)
mm->add_instruction(sum_op{}, one, two);
p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back();
std::ignore = p.eval({}).back();
EXPECT(is_shared(t.ctx, p.get_context()));
}
......@@ -475,7 +476,7 @@ TEST_CASE(eval_context2)
mm->add_instruction(id_ctx_op{}, one, two);
p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back();
std::ignore = p.eval({}).back();
// id_ctx_op will modify the context
EXPECT(not is_shared(t.ctx, p.get_context()));
}
......@@ -492,8 +493,8 @@ TEST_CASE(eval_context3)
p.compile(t);
// Finalizer will modify the context
EXPECT(not is_shared(t.ctx, p.get_context()));
auto ctx = p.get_context();
p.eval({}).back();
auto ctx = p.get_context();
std::ignore = p.eval({}).back();
EXPECT(is_shared(ctx, p.get_context()));
EXPECT(not is_shared(t.ctx, p.get_context()));
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
void run_prog(migraphx::program p,
const migraphx::target& t,
migraphx::parameter_map& m_in,
std::vector<float>& res)
{
p.compile(t);
migraphx::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}
// This test ensures that the codegen path doesn't round up literals,
// otherwise there are accuracy differences compared to ref.
// The values being passed in are 0.5 * (1/0.00787402),
// and after rounding must equal 63, not 64.
TEST_CASE(mul_literal_round_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1}};
auto l0 = mm->add_parameter("a", s0);
auto l1 = mm->add_literal(1 / 0.00787402f);
auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1);
auto round = mm->add_instruction(migraphx::make_op("round"), mul);
mm->add_return({round});
migraphx::parameter_map m;
std::vector<float> a = {0.5f};
m["a"] = migraphx::argument{s0, a.data()};
std::vector<float> ref_result;
migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, ref_result);
std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_mlir(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
std::vector<std::string> arg_names,
F f)
{
assert(inputs.size() == arg_names.size() && "One interior parameter name given per input.");
auto* mm = p.get_main_module();
auto* pm = p.create_module(name);
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
for(size_t i = 0, e = inputs.size(); i < e; ++i)
{
params.push_back(pm->add_parameter(arg_names[i], inputs[i]->get_shape()));
}
auto values = f(pm, params);
auto root = std::get<0>(values);
auto r = std::get<1>(values);
pm->add_return({r});
return mm->add_instruction(
migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root->get_operator())}}),
inputs,
{pm});
}
TEST_CASE(dot_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto fused =
add_mlir(p2,
"mlir_main:pointwise0",
{x, a, b},
{"x1", "y0", "y1"},
[=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]);
return std::make_tuple(dot, add);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_abs)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto abs = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("abs"));
mm->add_return({abs});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto fused = add_mlir(
p2, "mlir_main:pointwise0", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]);
auto abs = pm->add_instruction(migraphx::make_op("abs"), dot);
return std::make_tuple(dot, abs);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_tanh_fails)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh});
}
migraphx::program p2(p1);
// This pass should do nothing as int32_t tanh isn't supported.
run_pass(p1);
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[])
{
if(migraphx::gpu::mlir_enabled())
test::run(argc, argv);
return 0;
}
......@@ -187,12 +187,39 @@ module {
EXPECT(verify_mlir(m));
}
TEST_CASE(quant_dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::int8_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::int8_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::int32_type, {1, 5, 3}});
auto conv = m.add_instruction(migraphx::make_op("quant_dot"), arg0, arg1);
auto add = m.add_instruction(migraphx::make_op("add"), conv, arg2);
m.add_return({add});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32>
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
......@@ -246,4 +273,57 @@ module {
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto trunc = m.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), dot);
m.add_return({trunc});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::bool_type, {1, 5, 3}});
auto arg3 = m.add_parameter("arg3", {migraphx::shape::float_type, {1, 5, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto where = m.add_instruction(migraphx::make_op("where"), arg2, dot, arg3);
m.add_return({where});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -86,6 +86,25 @@ struct minus_op
};
struct pass_op
{
std::string name() const { return "pass"; }
migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
};
struct non_const_pass_op
{
std::string name() const { return "pass"; }
migraphx::argument
......@@ -176,9 +195,7 @@ struct pass_standard_op
struct nop
{
std::string name() const { return "nop"; }
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
migraphx::argument compute(const migraphx::shape&, const std::vector<migraphx::argument>&) const
{
return {};
}
......@@ -186,6 +203,21 @@ struct nop
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
};
struct tuple_op
{
std::string name() const { return "tuple_op"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
return {inputs};
}
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>& input_args) const
{
return input_args;
}
};
inline migraphx::literal get_2x2(int base = 0)
{
return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
......
......@@ -177,4 +177,10 @@ TEST_CASE(value_literal)
EXPECT(l4 == l2);
}
TEST_CASE(literal_to_string_float_precision)
{
migraphx::literal x{126.99993142003703f};
EXPECT(x.to_string() != "127");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment