Commit 5ec8f913 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by Ted Themistokleous
Browse files

Merge branch 'develop' into simplify_1_mul_div_ops

parents 32d69e8e d78bcdfb
...@@ -53,6 +53,7 @@ jobs: ...@@ -53,6 +53,7 @@ jobs:
CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \ CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \
-DMIGRAPHX_ENABLE_GPU=On \ -DMIGRAPHX_ENABLE_GPU=On \
-DMIGRAPHX_ENABLE_CPU=On \ -DMIGRAPHX_ENABLE_CPU=On \
-DMIGRAPHX_ENABLE_FPGA=On \
-DROCM_ENABLE_GH_ANNOTATIONS=On \ -DROCM_ENABLE_GH_ANNOTATIONS=On \
-DCLANG_TIDY_DEPEND_ON_TARGET=Off \ -DCLANG_TIDY_DEPEND_ON_TARGET=Off \
-DCLANG_TIDY_CACHE=/data/tidy-cache \ -DCLANG_TIDY_CACHE=/data/tidy-cache \
......
name: MIGraphX Performance Tests name: MIGraphX Performance Tests
on: on:
push:
branches: [develop]
pull_request: pull_request:
branches: [develop] branches: [develop]
types: [opened, synchronize, closed]
schedule: schedule:
- cron: "0 5 * * 1-6" - cron: "0 5 * * 1-6"
......
...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES) ...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.3) rocm_setup_version(VERSION 2.4)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -107,7 +107,7 @@ ...@@ -107,7 +107,7 @@
<summary>Use make_shared or make_unique instead of new</summary> <summary>Use make_shared or make_unique instead of new</summary>
</message> </message>
</rule> </rule>
<!-- <rule> <rule>
<tokenlist>raw</tokenlist> <tokenlist>raw</tokenlist>
<pattern><![CDATA[ \|\| ]]></pattern> <pattern><![CDATA[ \|\| ]]></pattern>
<message> <message>
...@@ -124,7 +124,7 @@ ...@@ -124,7 +124,7 @@
<severity>style</severity> <severity>style</severity>
<summary>Use 'not' instead of !</summary> <summary>Use 'not' instead of !</summary>
</message> </message>
</rule> --> </rule>
<!-- <rule> <!-- <rule>
<tokenlist>raw</tokenlist> <tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern> <pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
......
...@@ -53,8 +53,8 @@ int main(int argc, char** argv) ...@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx::program p; migraphx::program p;
if(cmdOptionExists(argv + 2, argv + argc, "--parse") || if(cmdOptionExists(argv + 2, argv + argc, "--parse") or
!cmdOptionExists(argv + 2, argv + argc, "--load")) not cmdOptionExists(argv + 2, argv + argc, "--load"))
{ {
std::cout << "Parsing ONNX File" << std::endl; std::cout << "Parsing ONNX File" << std::endl;
migraphx::onnx_options options; migraphx::onnx_options options;
......
...@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base ...@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{ {
throw std::runtime_error("sscal_custom_op must have 2 input arguments"); throw std::runtime_error("sscal_custom_op must have 2 input arguments");
} }
if(inputs[0].lengths().size() != 1 || inputs[0].lengths()[0] != 1) if(inputs[0].lengths().size() != 1 or inputs[0].lengths()[0] != 1)
{ {
throw std::runtime_error("first input argument to sscal_custom_op must be a scalar"); throw std::runtime_error("first input argument to sscal_custom_op must be a scalar");
} }
......
...@@ -51,16 +51,16 @@ int main(int argc, char** argv) ...@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char** begin = argv + 1; char** begin = argv + 1;
char** end = argv + argc; char** end = argv + argc;
const bool CPU = (std::find(begin, end, std::string("-c")) != end) || const bool CPU = (std::find(begin, end, std::string("-c")) != end) or
std::find(begin, end, std::string("--cpu")) != end; std::find(begin, end, std::string("--cpu")) != end;
const bool GPU = std::find(begin, end, std::string("-g")) != end || const bool GPU = std::find(begin, end, std::string("-g")) != end or
std::find(begin, end, std::string("--gpu")) != end; std::find(begin, end, std::string("--gpu")) != end;
const bool FP16 = std::find(begin, end, std::string("-f")) != end || const bool FP16 = std::find(begin, end, std::string("-f")) != end or
std::find(begin, end, std::string("--fp16")) != end; std::find(begin, end, std::string("--fp16")) != end;
const bool INT8 = std::find(begin, end, std::string("-i")) != end || const bool INT8 = std::find(begin, end, std::string("-i")) != end or
std::find(begin, end, std::string("--int8")) != end; std::find(begin, end, std::string("--int8")) != end;
const bool CALIB = std::find(begin, end, std::string("--cal")) != end; const bool CALIB = std::find(begin, end, std::string("--cal")) != end;
const bool PRINT = std::find(begin, end, std::string("-p")) != end || const bool PRINT = std::find(begin, end, std::string("-p")) != end or
std::find(begin, end, std::string("--print")) != end; std::find(begin, end, std::string("--print")) != end;
migraphx::program prog; migraphx::program prog;
...@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit) ...@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const int HEIGHT = 28; const int HEIGHT = 28;
const int WIDTH = 28; const int WIDTH = 28;
if(!file.is_open()) if(not file.is_open())
{ {
return; return;
} }
......
...@@ -82,6 +82,7 @@ add_library(migraphx ...@@ -82,6 +82,7 @@ add_library(migraphx
simplify_qdq.cpp simplify_qdq.cpp
sqlite.cpp sqlite.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
rewrite_quantization.cpp rewrite_quantization.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
...@@ -90,7 +91,6 @@ add_library(migraphx ...@@ -90,7 +91,6 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape); MIGRAPHX_HANDLE_CONSTRUCTOR(shape)
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); } friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
}; };
/** /**
...@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument); MIGRAPHX_HANDLE_CONSTRUCTOR(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
...@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout; return pout;
} }
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); } friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
}; };
/// A target for compilation /// A target for compilation
...@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target); MIGRAPHX_HANDLE_CONSTRUCTOR(target)
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes)
size_t size() const size_t size() const
{ {
...@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const std::vector<const char*> names() const
{ {
std::vector<const char*> result(this->size()); std::vector<const char*> result(this->size());
if(!result.empty()) if(not result.empty())
{ {
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr()); call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
} }
...@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
...@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments); MIGRAPHX_HANDLE_CONSTRUCTOR(arguments)
size_t size() const size_t size() const
{ {
...@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(shapes)
size_t size() const size_t size() const
{ {
...@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(operation); MIGRAPHX_HANDLE_CONSTRUCTOR(operation)
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction); MIGRAPHX_HANDLE_CONSTRUCTOR(instruction)
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions); MIGRAPHX_HANDLE_CONSTRUCTOR(instructions)
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -797,7 +797,7 @@ struct module; ...@@ -797,7 +797,7 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules); MIGRAPHX_HANDLE_CONSTRUCTOR(modules)
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
...@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options); MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options)
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program); MIGRAPHX_HANDLE_CONSTRUCTOR(program)
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -1015,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -1015,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return not(px == py); }
}; };
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options); MIGRAPHX_HANDLE_CONSTRUCTOR(file_options)
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format // set file format
...@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options); MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options)
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options); MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options)
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names)
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options)
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
......
...@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m, ...@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0]; auto a = args[0];
auto b = args[1]; auto b = args[1];
auto input_type = a->get_shape().type(); auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0)) if(not float_equal(alpha.at<float>(0), 1.0))
{ {
auto alpha_literal = m.add_literal(alpha); auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a}); a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
...@@ -25,13 +25,10 @@ ...@@ -25,13 +25,10 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{ {
migraphx::program p; migraphx::program p;
...@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1))); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto x_main_module_2 = mmain->add_literal(migraphx::abs( auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2))); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto x_input_1 = mmain->add_parameter( auto x_0 = mmain->add_parameter(
"input.1", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); "0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_main_module_4 = mmain->add_literal( auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 3));
auto x_main_module_5 = mmain->add_literal( auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 4));
auto x_main_module_6 = mmain->add_literal( auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 5));
auto x_main_module_7 = mmain->add_literal( auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 6));
auto x_main_module_8 = mmain->add_literal( auto x_main_module_8 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 7)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 7));
auto x_main_module_9 = mmain->add_literal( auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 8)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 8));
auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_10 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 9));
auto x_main_module_11 = mmain->add_literal( auto x_main_module_11 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10)); migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 10));
auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_12 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 11));
auto x_main_module_13 = mmain->add_literal( auto x_main_module_13 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12)); migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 12));
auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_14 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13));
auto x_main_module_15 = mmain->add_literal( auto x_main_module_15 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 14)); migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 14));
auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_16 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 15)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 15));
auto x_main_module_17 = mmain->add_literal( auto x_main_module_17 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16)); migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 16));
auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_18 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 17)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 17));
auto x_main_module_19 = mmain->add_literal( auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 18)); migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction( auto x_main_module_20 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
migraphx::from_json_string( "4],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}")), x_0,
x_input_1,
x_main_module_18);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,64,55,55]}")),
x_main_module_19); x_main_module_19);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,64,55,55]}"), x_main_module_18);
auto x_main_module_22 = auto x_main_module_22 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21); mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22); auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
auto x_main_module_24 = mmain->add_instruction( auto x_main_module_24 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op(
"pooling", "pooling",
migraphx::from_json_string( "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_23); x_main_module_23);
auto x_main_module_25 = mmain->add_instruction( auto x_main_module_25 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}")),
x_main_module_24, x_main_module_24,
x_main_module_14); x_main_module_17);
auto x_main_module_26 = mmain->add_instruction( auto x_main_module_26 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,192,27,27]}"), x_main_module_16);
migraphx::from_json_string("{axis:1,out_lens:[1,192,27,27]}")),
x_main_module_15);
auto x_main_module_27 = auto x_main_module_27 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26); mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27); auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto x_main_module_29 = mmain->add_instruction( auto x_main_module_29 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op(
"pooling", "pooling",
migraphx::from_json_string( "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_28); x_main_module_28);
auto x_main_module_30 = mmain->add_instruction( auto x_main_module_30 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_29, x_main_module_29,
x_main_module_12); x_main_module_15);
auto x_main_module_31 = mmain->add_instruction( auto x_main_module_31 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,384,13,13]}"), x_main_module_14);
migraphx::from_json_string("{axis:1,out_lens:[1,384,13,13]}")),
x_main_module_13);
auto x_main_module_32 = auto x_main_module_32 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31); mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32); auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction( auto x_main_module_34 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_33, x_main_module_33,
x_main_module_10); x_main_module_13);
auto x_main_module_35 = mmain->add_instruction( auto x_main_module_35 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_12);
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_11);
auto x_main_module_36 = auto x_main_module_36 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35); mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36); auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction( auto x_main_module_38 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_37, x_main_module_37,
x_main_module_16); x_main_module_11);
auto x_main_module_39 = mmain->add_instruction( auto x_main_module_39 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_10);
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_17);
auto x_main_module_40 = auto x_main_module_40 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39); mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40); auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
auto x_main_module_42 = mmain->add_instruction( auto x_main_module_42 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op(
"pooling", "pooling",
migraphx::from_json_string( "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_41); x_main_module_41);
auto x_main_module_43 = mmain->add_instruction( auto x_main_module_43 =
migraphx::make_op("reshape", migraphx::from_json_string("{dims:[1,9216]}")), mmain->add_instruction(migraphx::make_json_op("flatten", "{axis:1}"), x_main_module_42);
x_main_module_42); auto x_main_module_44 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_43);
auto x_main_module_44 = mmain->add_instruction( auto x_main_module_45 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")), migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_9);
x_main_module_6); auto x_main_module_46 =
auto x_main_module_45 = mmain->add_instruction(migraphx::make_op("dot"), x_main_module_44, x_main_module_45);
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_43, x_main_module_44);
auto x_main_module_46 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_7);
auto x_main_module_47 = mmain->add_instruction( auto x_main_module_47 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")), migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_8);
x_main_module_2); auto x_main_module_48 = mmain->add_instruction(
auto x_main_module_48 = migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_2);
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_46, x_main_module_47);
auto x_main_module_49 = auto x_main_module_49 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_45, x_main_module_48); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_47, x_main_module_48);
auto x_main_module_50 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_49); auto x_main_module_50 =
auto x_main_module_51 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("add"), x_main_module_46, x_main_module_49);
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")), auto x_main_module_51 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_50);
x_main_module_4); auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_51);
auto x_main_module_52 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_50, x_main_module_51);
auto x_main_module_53 = mmain->add_instruction( auto x_main_module_53 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")), migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_7);
x_main_module_5); auto x_main_module_54 =
auto x_main_module_54 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("dot"), x_main_module_52, x_main_module_53);
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")), auto x_main_module_55 = mmain->add_instruction(
x_main_module_1); migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_6);
auto x_main_module_55 = auto x_main_module_56 = mmain->add_instruction(
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_53, x_main_module_54); migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_1);
auto x_main_module_56 = auto x_main_module_57 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_52, x_main_module_55); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_55, x_main_module_56);
auto x_main_module_57 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_56); auto x_main_module_58 =
auto x_main_module_58 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("add"), x_main_module_54, x_main_module_57);
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")), auto x_main_module_59 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_58);
x_main_module_8);
auto x_main_module_59 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_57, x_main_module_58);
auto x_main_module_60 = mmain->add_instruction( auto x_main_module_60 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")), migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_5);
x_main_module_9); auto x_main_module_61 =
auto x_main_module_61 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("dot"), x_main_module_59, x_main_module_60);
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")), auto x_main_module_62 = mmain->add_instruction(
x_main_module_0); migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_4);
auto x_main_module_62 = auto x_main_module_63 = mmain->add_instruction(
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_60, x_main_module_61); migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_0);
auto x_main_module_63 = auto x_main_module_64 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_59, x_main_module_62); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_62, x_main_module_63);
mmain->add_return({x_main_module_63}); auto x_main_module_65 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_61, x_main_module_64);
mmain->add_return({x_main_module_65});
return p; return p;
} }
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
...@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const ...@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto lens = ins->inputs().front()->get_shape().lens(); auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator()); auto concat_op = concat_opt.get_concat(ins->get_operator());
std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name()); std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
if(axis_index == 0 || if(axis_index == 0 or
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; })) std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{ {
// Last input should be an allocation // Last input should be an allocation
......
...@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(!try_compute_shape(output, input_shapes, mods)) if(not try_compute_shape(output, input_shapes, mods))
{ {
return false; return false;
} }
......
...@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename) ...@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is.seekg(0, std::ios::beg); is.seekg(0, std::ios::beg);
T buffer(size, 0); T buffer(size, 0);
if(!is.read(&buffer[0], size)) if(not is.read(&buffer[0], size))
MIGRAPHX_THROW("Error reading file: " + filename); MIGRAPHX_THROW("Error reading file: " + filename);
return buffer; return buffer;
} }
......
...@@ -205,7 +205,7 @@ struct allocation_model ...@@ -205,7 +205,7 @@ struct allocation_model
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -267,7 +267,7 @@ struct allocation_model ...@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -101,7 +101,7 @@ struct check_shapes ...@@ -101,7 +101,7 @@ struct check_shapes
const check_shapes& nelements(std::size_t n) const const check_shapes& nelements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements"); MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
return *this; return *this;
} }
...@@ -164,7 +164,7 @@ struct check_shapes ...@@ -164,7 +164,7 @@ struct check_shapes
*/ */
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(not this->same([](const shape& s) { return s; }))
MIGRAPHX_THROW(prefix() + "Shapes do not match"); MIGRAPHX_THROW(prefix() + "Shapes do not match");
return *this; return *this;
} }
...@@ -174,7 +174,7 @@ struct check_shapes ...@@ -174,7 +174,7 @@ struct check_shapes
*/ */
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(not this->same([](const shape& s) { return s.type(); }))
MIGRAPHX_THROW(prefix() + "Types do not match"); MIGRAPHX_THROW(prefix() + "Types do not match");
return *this; return *this;
} }
...@@ -184,10 +184,10 @@ struct check_shapes ...@@ -184,10 +184,10 @@ struct check_shapes
*/ */
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens(); })) if(not this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match"); MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); })) if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); })) if(not this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match"); MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this; return *this;
} }
...@@ -197,7 +197,7 @@ struct check_shapes ...@@ -197,7 +197,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens().size(); })) if(not this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
...@@ -207,7 +207,7 @@ struct check_shapes ...@@ -207,7 +207,7 @@ struct check_shapes
*/ */
const check_shapes& standard() const const check_shapes& standard() const
{ {
if(!this->all_of([](const shape& s) { return s.standard(); })) if(not this->all_of([](const shape& s) { return s.standard(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout"); MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
return *this; return *this;
} }
...@@ -217,7 +217,7 @@ struct check_shapes ...@@ -217,7 +217,7 @@ struct check_shapes
*/ */
const check_shapes& standard_or_scalar() const const check_shapes& standard_or_scalar() const
{ {
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); })) if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout"); MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this; return *this;
} }
...@@ -227,7 +227,7 @@ struct check_shapes ...@@ -227,7 +227,7 @@ struct check_shapes
*/ */
const check_shapes& packed() const const check_shapes& packed() const
{ {
if(!this->all_of([](const shape& s) { return s.packed(); })) if(not this->all_of([](const shape& s) { return s.packed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed"); MIGRAPHX_THROW(prefix() + "Shapes are not packed");
return *this; return *this;
} }
...@@ -237,7 +237,7 @@ struct check_shapes ...@@ -237,7 +237,7 @@ struct check_shapes
*/ */
const check_shapes& packed_or_broadcasted() const const check_shapes& packed_or_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); })) if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted"); MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
return *this; return *this;
} }
...@@ -247,7 +247,7 @@ struct check_shapes ...@@ -247,7 +247,7 @@ struct check_shapes
*/ */
const check_shapes& tuple_type() const const check_shapes& tuple_type() const
{ {
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; })) if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
MIGRAPHX_THROW(prefix() + "Shapes are not tuple!"); MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
return *this; return *this;
} }
...@@ -257,7 +257,7 @@ struct check_shapes ...@@ -257,7 +257,7 @@ struct check_shapes
*/ */
const check_shapes& not_transposed() const const check_shapes& not_transposed() const
{ {
if(!this->all_of([](const shape& s) { return not s.transposed(); })) if(not this->all_of([](const shape& s) { return not s.transposed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are transposed"); MIGRAPHX_THROW(prefix() + "Shapes are transposed");
return *this; return *this;
} }
...@@ -267,7 +267,7 @@ struct check_shapes ...@@ -267,7 +267,7 @@ struct check_shapes
*/ */
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) if(not this->all_of([](const shape& s) { return not s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are broadcasted"); MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
return *this; return *this;
} }
...@@ -278,7 +278,7 @@ struct check_shapes ...@@ -278,7 +278,7 @@ struct check_shapes
*/ */
const check_shapes& elements(std::size_t n) const const check_shapes& elements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements"); MIGRAPHX_THROW(prefix() + "Wrong number of elements");
return *this; return *this;
} }
...@@ -288,7 +288,8 @@ struct check_shapes ...@@ -288,7 +288,8 @@ struct check_shapes
*/ */
const check_shapes& batch_not_transposed() const const check_shapes& batch_not_transposed() const
{ {
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); })) if(not this->all_of(
[&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
MIGRAPHX_THROW(prefix() + "Batch size is transposed"); MIGRAPHX_THROW(prefix() + "Batch size is transposed");
return *this; return *this;
} }
......
...@@ -183,7 +183,7 @@ struct concat_optimization ...@@ -183,7 +183,7 @@ struct concat_optimization
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -233,7 +233,7 @@ struct concat_optimization ...@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -246,7 +246,7 @@ struct context ...@@ -246,7 +246,7 @@ struct context
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -306,7 +306,7 @@ struct context ...@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment