Unverified Commit 9c91c08d authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents a56bb11d c1b8c975
...@@ -130,6 +130,8 @@ struct index ...@@ -130,6 +130,8 @@ struct index
return blockDim.x; return blockDim.x;
} }
#endif #endif
constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
template <class N, class Stride> template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride) static constexpr auto max_stride_iterations(N n, Stride stride)
{ {
...@@ -231,6 +233,12 @@ struct index ...@@ -231,6 +233,12 @@ struct index
{ {
for_stride<true>(local, n, nlocal(), f); for_stride<true>(local, n, nlocal(), f);
} }
template <class F, class N>
__device__ void group_stride(N n, F f) const
{
for_stride<false>(group, n, ngroup(), f);
}
}; };
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_NLOCAL
......
...@@ -188,10 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max) ...@@ -188,10 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
// Add overloads for half that calls the float version, this should use "hmax" and "hmin" once MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
// perf CI docker is upgraded to rocm-5.5 MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())> template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b) constexpr auto max(const T& a, const T& b)
......
...@@ -22,12 +22,19 @@ ...@@ -22,12 +22,19 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iterator> #include <iterator>
#include <migraphx/gpu/lowering.hpp> #include <utility>
#include <functional>
#include <algorithm>
#include <map>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
...@@ -35,17 +42,12 @@ ...@@ -35,17 +42,12 @@
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp> #include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -54,6 +56,7 @@ namespace gpu { ...@@ -54,6 +56,7 @@ namespace gpu {
struct miopen_apply struct miopen_apply
{ {
module* mod = nullptr; module* mod = nullptr;
module_pass_manager* mpm = nullptr;
const lowering* pass = nullptr; const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{}; instruction_ref last{};
...@@ -83,8 +86,7 @@ struct miopen_apply ...@@ -83,8 +86,7 @@ struct miopen_apply
auto& ctx = get_context(); auto& ctx = get_context();
int8_x4_format = get_int8_x4_format(ctx); int8_x4_format = get_int8_x4_format(ctx);
compute_fp32 = get_compute_fp32_flag(); compute_fp32 = get_compute_fp32_flag();
// TODO: Set Offload copy based on root modules' compile options offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
...@@ -376,7 +378,10 @@ struct miopen_apply ...@@ -376,7 +378,10 @@ struct miopen_apply
} }
}; };
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); } void lowering::apply(module_pass_manager& mpm) const
{
miopen_apply{&mpm.get_module(), &mpm, this}.apply();
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -122,6 +122,9 @@ struct mlir_handle ...@@ -122,6 +122,9 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT #define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy); using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_thread_pool = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirLlvmThreadPool, mlirLlvmThreadPoolDestroy);
using mlir_dialect_registry = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirDialectRegistry,
mlirDialectRegistryDestroy);
using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy); using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy); using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags, using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
...@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch) ...@@ -173,16 +176,38 @@ bool has_xdlops(const std::string& target_arch)
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
: ctx(mlirContextCreate()), : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location))
{ {
MlirDialectRegistry registry = mlirDialectRegistryCreate(); mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirRegisterRocMLIRDialects(registry);
mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextLoadAllAvailableDialects(ctx.get()); mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry); }
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
static mlir_dialect_registry& get_dialect_registry()
{
static std::once_flag init_guard;
static mlir_dialect_registry the_registry;
// The MLIR registration functions (for dialects and passes) are not
// necessarily thread-safe and need to be executed exactly once
// (especially since they eventually call non-thread-safe LLVM
// initilizations).
std::call_once(init_guard, [&]() {
the_registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(the_registry.get());
mlirRegisterRocMLIRPasses();
});
return the_registry;
}
static mlir_thread_pool& get_thread_pool()
{
// To save on overhead, we create one LLVM thread pool and reuse it
// across all MLIR contexts as recommended by MLIR upstream.
// Note that this is thread-safe as of C++11.
static mlir_thread_pool the_pool = mlirLlvmThreadPoolCreate();
return the_pool;
} }
MlirType make_type(shape::type_t t) const MlirType make_type(shape::type_t t) const
...@@ -244,8 +269,6 @@ struct mlir_program ...@@ -244,8 +269,6 @@ struct mlir_program
MlirAttribute attribute(std::int64_t i) const MlirAttribute attribute(std::int64_t i) const
{ {
if(i < 0)
MIGRAPHX_THROW("MLIR cant handle negative values since they are ambiguous");
return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i); return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
} }
MlirAttribute attribute(std::uint64_t i) const MlirAttribute attribute(std::uint64_t i) const
......
...@@ -57,6 +57,7 @@ ...@@ -57,6 +57,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp> #include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
...@@ -74,6 +75,8 @@ namespace gpu { ...@@ -74,6 +75,8 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct id_pass struct id_pass
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
...@@ -121,7 +124,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -121,7 +124,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module{}, inline_module{},
rewrite_pooling{}, rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_gelu{}, enable_pass(options.fast_math, rewrite_gelu{}),
optimize_module{}, optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}), enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{}, dead_code_elimination{},
...@@ -133,6 +136,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -133,6 +136,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
...@@ -150,7 +155,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -150,7 +155,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}}, adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{}, dead_code_elimination{},
compile_ops{&ctx}, compile_ops{&ctx, options.exhaustive_tune},
dead_code_elimination{}, dead_code_elimination{},
promote_literals{}, promote_literals{},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/driver/perf.hpp> #include <migraphx/gpu/time_op.hpp>
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace driver {
std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0) std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
{ {
...@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n) ...@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
return std::make_pair(host_time / n, device_time / n); return std::make_pair(host_time / n, device_time / n);
} }
} // namespace driver
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/hash.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -519,6 +520,38 @@ std::ostream& operator<<(std::ostream& os, const value& d) ...@@ -519,6 +520,38 @@ std::ostream& operator<<(std::ostream& os, const value& d)
return os; return os;
} }
template <class T>
std::size_t value_hash(const std::string& key, const T& x)
{
std::size_t h = hash_value(key);
hash_combine(h, x);
return h;
}
std::size_t value_hash(const std::string& key, std::nullptr_t) { return hash_value(key); }
std::size_t value_hash(const std::string& key, const std::vector<value>& x)
{
std::size_t h = hash_value(key);
for(const auto& v : x)
hash_combine(h, v);
return h;
}
std::size_t value_hash(const std::string& key, const value::binary& x)
{
std::size_t h = hash_value(key);
for(const auto& v : x)
hash_combine(h, v);
return h;
}
std::size_t value::hash() const
{
std::size_t h = 0;
this->visit_value([&](const auto& a) { h = value_hash(this->get_key(), a); });
return h;
}
void value::debug_print(bool show_type) const void value::debug_print(bool show_type) const
{ {
if(show_type) if(show_type)
......
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
cmake_policy(SET CMP0057 NEW) cmake_policy(SET CMP0057 NEW)
include(CTest)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
include(ProcessorCount) include(ProcessorCount)
ProcessorCount(N) ProcessorCount(N)
......
...@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y) ...@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y)
abort(); abort();
} }
int main() int main(void)
{ {
char name[1024]; char name[1024];
migraphx_operation_t op; migraphx_operation_t op;
......
...@@ -151,7 +151,7 @@ TEST_CASE(dynamic_batch_load_and_run_offload) ...@@ -151,7 +151,7 @@ TEST_CASE(dynamic_batch_load_and_run_offload)
c_options.set_offload_copy(); c_options.set_offload_copy();
p.compile(migraphx::target("gpu"), c_options); p.compile(migraphx::target("gpu"), c_options);
auto out_shapes = p.get_output_shapes(); auto out_shapes = p.get_output_shapes();
CHECK(out_shapes.size() == 1); EXPECT(out_shapes.size() == 1);
EXPECT(out_shapes[0].dynamic()); EXPECT(out_shapes[0].dynamic());
// batch size = 2 // batch size = 2
...@@ -165,9 +165,9 @@ TEST_CASE(dynamic_batch_load_and_run_offload) ...@@ -165,9 +165,9 @@ TEST_CASE(dynamic_batch_load_and_run_offload)
migraphx::argument(migraphx::shape(migraphx_shape_float_type, {2, 3, 3, 3}), c.data())); migraphx::argument(migraphx::shape(migraphx_shape_float_type, {2, 3, 3, 3}), c.data()));
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
CHECK(shapes_before.size() == outputs.size()); EXPECT(shapes_before.size() == outputs.size());
CHECK(bool{outputs.front().get_shape() == EXPECT(bool{outputs.front().get_shape() ==
migraphx::shape(migraphx_shape_float_type, {2, 1, 3, 3})}); migraphx::shape(migraphx_shape_float_type, {2, 2, 2, 2})});
} }
TEST_CASE(load_and_run_async) TEST_CASE(load_and_run_async)
......
...@@ -41,7 +41,7 @@ TEST_CASE(simple_test) ...@@ -41,7 +41,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
...@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop) ...@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
...@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2) ...@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
...@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1) ...@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
...@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2) ...@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(minus_op{}, one, two); mm->add_instruction(migraphx::make_op("sub"), one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
...@@ -121,11 +121,11 @@ TEST_CASE(depth_test) ...@@ -121,11 +121,11 @@ TEST_CASE(depth_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto x1 = mm->add_instruction(sum_op{}, one, two); auto x1 = mm->add_instruction(migraphx::make_op("add"), one, two);
auto x2 = mm->add_instruction(sum_op{}, one, two); auto x2 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
...@@ -141,7 +141,7 @@ TEST_CASE(undefined_test) ...@@ -141,7 +141,7 @@ TEST_CASE(undefined_test)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1); EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice) ...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice)
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4); EXPECT(std::distance(mm->begin(), mm->end()) == 4);
} }
...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated) ...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated)
EXPECT(p == create_program()); EXPECT(p == create_program());
} }
TEST_CASE(tuple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(tuple_op{}, one, two);
mm->add_return({one, two});
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -45,7 +45,7 @@ TEST_CASE(simple_test) ...@@ -45,7 +45,7 @@ TEST_CASE(simple_test)
auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one); auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two); auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity); mm->add_instruction(migraphx::make_op("add"), one_identity, two_identity);
run_pass(p); run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end) ...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
...@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency) ...@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency)
auto one = mm->add_literal(1.0); auto one = mm->add_literal(1.0);
auto two = mm->add_literal(2.0); auto two = mm->add_literal(2.0);
auto three = mm->add_literal(3.0); auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, ans, three); mm->add_instruction(migraphx::make_op("add"), ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/make_op.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -49,7 +50,7 @@ struct id_target ...@@ -49,7 +50,7 @@ struct id_target
struct id_ctx_op struct id_ctx_op
{ {
std::string name() const { return "id_ctx_op"; } std::string name() const { return ""; }
migraphx::argument migraphx::argument
compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
...@@ -156,7 +157,7 @@ TEST_CASE(literal_test1) ...@@ -156,7 +157,7 @@ TEST_CASE(literal_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -168,8 +169,8 @@ TEST_CASE(literal_test2) ...@@ -168,8 +169,8 @@ TEST_CASE(literal_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, sum1, two); mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5}); EXPECT(result == migraphx::literal{5});
...@@ -182,7 +183,7 @@ TEST_CASE(print_test) ...@@ -182,7 +183,7 @@ TEST_CASE(print_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, x, two); mm->add_instruction(migraphx::make_op("add"), x, two);
std::stringstream ss; std::stringstream ss;
ss << p; ss << p;
...@@ -197,7 +198,7 @@ TEST_CASE(param_test) ...@@ -197,7 +198,7 @@ TEST_CASE(param_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type}); auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()}, auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}}) {"y", migraphx::literal{2}.get_argument()}})
.back(); .back();
...@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test) ...@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
[&] { [&] {
p.eval({ p.eval({
...@@ -245,7 +246,7 @@ TEST_CASE(get_param1) ...@@ -245,7 +246,7 @@ TEST_CASE(get_param1)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(bool{p.get_parameter("x") == x}); EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y}); EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
...@@ -257,7 +258,7 @@ TEST_CASE(get_param2) ...@@ -257,7 +258,7 @@ TEST_CASE(get_param2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
} }
...@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes) ...@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
auto m = p.get_parameter_shapes(); auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0); EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s); EXPECT(m.at("x") == s);
...@@ -281,8 +282,8 @@ TEST_CASE(replace_test) ...@@ -281,8 +282,8 @@ TEST_CASE(replace_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->replace_instruction(sum, minus_op{}, two, one); mm->replace_instruction(sum, migraphx::make_op("sub"), two, one);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test) ...@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->replace_instruction(sum, minus); mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2) ...@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(pass_op{}, minus); mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum); mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test) ...@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one); auto sum = mm->add_instruction(migraphx::make_op("add"), two, one);
sum->replace(minus_op{}); sum->replace(migraphx::make_op("sub"));
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw) ...@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); })); EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
} }
...@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test) ...@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, sum1, two); mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two); auto sum0 = mm->insert_instruction(sum1, migraphx::make_op("add"), two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two); mm->replace_instruction(sum1, migraphx::make_op("sub"), sum0, two);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -372,8 +373,8 @@ TEST_CASE(remove_test1) ...@@ -372,8 +373,8 @@ TEST_CASE(remove_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one); auto removed = mm->add_instruction(migraphx::make_op("sub"), sum, one);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -388,8 +389,8 @@ TEST_CASE(remove_test2) ...@@ -388,8 +389,8 @@ TEST_CASE(remove_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto removed = mm->add_instruction(minus_op{}, two, one); auto removed = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -404,7 +405,7 @@ TEST_CASE(target_test) ...@@ -404,7 +405,7 @@ TEST_CASE(target_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(id_target{}); p.compile(id_target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -460,7 +461,7 @@ TEST_CASE(eval_context1) ...@@ -460,7 +461,7 @@ TEST_CASE(eval_context1)
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back(); std::ignore = p.eval({}).back();
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
} }
...@@ -475,7 +476,7 @@ TEST_CASE(eval_context2) ...@@ -475,7 +476,7 @@ TEST_CASE(eval_context2)
mm->add_instruction(id_ctx_op{}, one, two); mm->add_instruction(id_ctx_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back(); std::ignore = p.eval({}).back();
// id_ctx_op will modify the context // id_ctx_op will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
...@@ -493,7 +494,7 @@ TEST_CASE(eval_context3) ...@@ -493,7 +494,7 @@ TEST_CASE(eval_context3)
// Finalizer will modify the context // Finalizer will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
auto ctx = p.get_context(); auto ctx = p.get_context();
p.eval({}).back(); std::ignore = p.eval({}).back();
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
void run_prog(migraphx::program p,
const migraphx::target& t,
migraphx::parameter_map& m_in,
std::vector<float>& res)
{
p.compile(t);
migraphx::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}
// This test ensures that the codegen path doesn't round up literals,
// otherwise there are accuracy differences compared to ref.
// The values being passed in are 0.5 * (1/0.00787402),
// and after rounding must equal 63, not 64.
TEST_CASE(mul_literal_round_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1}};
auto l0 = mm->add_parameter("a", s0);
auto l1 = mm->add_literal(1 / 0.00787402f);
auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1);
auto round = mm->add_instruction(migraphx::make_op("round"), mul);
mm->add_return({round});
migraphx::parameter_map m;
std::vector<float> a = {0.5f};
m["a"] = migraphx::argument{s0, a.data()};
std::vector<float> ref_result;
migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, ref_result);
std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_mlir(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
std::vector<std::string> arg_names,
F f)
{
assert(inputs.size() == arg_names.size() && "One interior parameter name given per input.");
auto* mm = p.get_main_module();
auto* pm = p.create_module(name);
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
for(size_t i = 0, e = inputs.size(); i < e; ++i)
{
params.push_back(pm->add_parameter(arg_names[i], inputs[i]->get_shape()));
}
auto values = f(pm, params);
auto root = std::get<0>(values);
auto r = std::get<1>(values);
pm->add_return({r});
return mm->add_instruction(
migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root->get_operator())}}),
inputs,
{pm});
}
TEST_CASE(dot_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto fused =
add_mlir(p2,
"mlir_main:pointwise0",
{x, a, b},
{"x1", "y0", "y1"},
[=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]);
return std::make_tuple(dot, add);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_abs)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto abs = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("abs"));
mm->add_return({abs});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto fused = add_mlir(
p2, "mlir_main:pointwise0", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]);
auto abs = pm->add_instruction(migraphx::make_op("abs"), dot);
return std::make_tuple(dot, abs);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_tanh_fails)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh});
}
migraphx::program p2(p1);
// This pass should do nothing as int32_t tanh isn't supported.
run_pass(p1);
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[])
{
if(migraphx::gpu::mlir_enabled())
test::run(argc, argv);
return 0;
}
...@@ -187,12 +187,39 @@ module { ...@@ -187,12 +187,39 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
TEST_CASE(quant_dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::int8_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::int8_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::int32_type, {1, 5, 3}});
auto conv = m.add_instruction(migraphx::make_op("quant_dot"), arg0, arg1);
auto add = m.add_instruction(migraphx::make_op("add"), conv, arg2);
m.add_return({add});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_add) TEST_CASE(dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} { func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32> %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32> return %1 : tensor<1x5x3xf32>
} }
...@@ -246,4 +273,57 @@ module { ...@@ -246,4 +273,57 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto trunc = m.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), dot);
m.add_return({trunc});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::bool_type, {1, 5, 3}});
auto arg3 = m.add_parameter("arg3", {migraphx::shape::float_type, {1, 5, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto where = m.add_instruction(migraphx::make_op("where"), arg2, dot, arg3);
m.add_return({where});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
...@@ -177,4 +177,10 @@ TEST_CASE(value_literal) ...@@ -177,4 +177,10 @@ TEST_CASE(value_literal)
EXPECT(l4 == l2); EXPECT(l4 == l2);
} }
TEST_CASE(literal_to_string_float_precision)
{
migraphx::literal x{126.99993142003703f};
EXPECT(x.to_string() != "127");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
5a43828b3d73028bfd33b3856f82698d9ab02cb1 fbf08c4b4dce5da245189203d9f6cfc41f6663a2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment