"tests/pipelines/vscode:/vscode.git/clone" did not exist on "5915c2985db162278e09196160d796166c89ad12"
Commit bf6f82d8 authored by Paul's avatar Paul
Browse files

Merge from develop

parents 6a0797e2 b93f5320
......@@ -3,6 +3,8 @@ CheckOptions:
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|FALLTHROUGH|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_GENERATE_|_DETAIL_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED'
- key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion
value: 0
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle
......
......@@ -56,6 +56,7 @@ rocm_enable_clang_tidy(
-clang-diagnostic-extern-c-compat
-clang-diagnostic-disabled-macro-expansion
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
......@@ -83,6 +84,7 @@ rocm_enable_clang_tidy(
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-transparent-functors
-performance-type-promotion-in-math-fn
-readability-braces-around-statements
-readability-else-after-return
-readability-named-parameter
......
......@@ -44,13 +44,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
rm -rf /var/lib/apt/lists/*
# Install cget
RUN pip install cget
# RUN pip install cget
RUN pip install https://github.com/pfultz2/cget/archive/57b3289000fcdb3b7e424c60a35ea09bc44d8538.tar.gz
# Install rclone
RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
# Install hcc
RUN rclone -b roc-2.0.x -c 757fb492517b80e7c86338af5fc1a43d63cb25a9 https://github.com/RadeonOpenCompute/hcc.git /hcc
RUN rclone -b roc-2.3.x -c fd93baed7dcc4fe8019b5fdc90213bfe7c298245 https://github.com/RadeonOpenCompute/hcc.git /hcc
RUN cget -p $PREFIX install hcc,/hcc
# Use hcc
......@@ -73,3 +74,8 @@ ENV LD_LIBRARY_PATH=$PREFIX/lib
# Install doc requirements
ADD doc/requirements.txt /doc-requirements.txt
RUN pip install -r /doc-requirements.txt
# Setup ubsan environment to printstacktrace
RUN ln -s /usr/bin/llvm-symbolizer-5.0 /usr/local/bin/llvm-symbolizer
ENV UBSAN_OPTIONS=print_stacktrace=1
ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1
......@@ -90,6 +90,7 @@ else()
-Wno-double-promotion
-Wno-exit-time-destructors
-Wno-extra-semi
-Wno-extra-semi-stmt
-Wno-float-conversion
-Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments
......@@ -104,7 +105,7 @@ else()
else()
list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers
-Wno-deprecated-declarations
# -Wno-deprecated-declarations
)
endif()
add_definitions(${CMAKE_COMPILER_WARNINGS})
......
......@@ -89,49 +89,76 @@
<summary>Use manage pointer for resource management</summary>
</message>
</rule>
<rule>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[hipLaunchKernelGGL \( (?!\( \w+ < \w+ > \))]]></pattern>
<message>
<id>UseDeviceLaunch</id>
<severity>style</severity>
<summary>Use device::launch instead</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern>! !</pattern>
<pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { [^{}]* (return|throw|break|continue) [^;]* ; } else {]]></pattern>
<message>
<id>doubleNegative</id>
<id>UnnecessaryElseStatement</id>
<severity>style</severity>
<summary>Double negative is always positive</summary>
<summary>Else statement is not necessary.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ (\||&) \w+ \)]]></pattern>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
<message>
<id>BitwiseOperatorInConditional</id>
<id>RedundantConditionalOperator</id>
<severity>style</severity>
<summary>Bitwise operator found in if statement.</summary>
<summary>Conditional operator is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^)]+ \) { if \( [^)]+ \) ({[^{}]*(?1)*[^{}]*}) }]]></pattern>
<pattern><![CDATA[switch (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message>
<id>CollapsibleIfStatements</id>
<id>EmptySwitchStatement</id>
<severity>style</severity>
<summary>These two if statements can be collapsed into one.</summary>
<summary>Empty switch statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[catch \( [^())]+ \) { }]]></pattern>
<pattern><![CDATA[(?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w) ; \1 = [^;]+ ; return \1 ;]]></pattern>
<message>
<id>EmptyCatchStatement</id>
<id>RedundantLocalVariable</id>
<severity>style</severity>
<summary>An empty catch statement.</summary>
<summary>Variable is returned immediately after its declaration, can be simplified to just return expression.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[do { } while \(]]></pattern>
<pattern><![CDATA[for \( ; [^;]+ ; \)]]></pattern>
<message>
<id>EmptyDoWhileStatement</id>
<id>ForLoopShouldBeWhileLoop</id>
<severity>style</severity>
<summary>Empty do-while.</summary>
<summary>For loop should be written as a while loop.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[while (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message>
<id>EmptyWhileStatement</id>
<severity>style</severity>
<summary>Empty while statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ (\||&) \w+ \)]]></pattern>
<message>
<id>BitwiseOperatorInConditional</id>
<severity>style</severity>
<summary>Bitwise operator found in if statement.</summary>
</message>
</rule>
<rule>
......@@ -145,7 +172,7 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( [^()]+ \) { }]]></pattern>
<pattern><![CDATA[for (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message>
<id>EmptyForStatement</id>
<severity>style</severity>
......@@ -154,7 +181,7 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { }]]></pattern>
<pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message>
<id>EmptyIfStatement</id>
<severity>style</severity>
......@@ -163,43 +190,52 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[switch \( [^()]+ \) { }]]></pattern>
<pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { return (true|false) ; } else { return (true|false) ; }]]></pattern>
<message>
<id>EmptySwitchStatement</id>
<id>RedundantIfStatement</id>
<severity>style</severity>
<summary>Empty switch statement.</summary>
<summary>The if statement is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[while \( [^()]+ \) { }]]></pattern>
<pattern><![CDATA[! !]]></pattern>
<message>
<id>EmptyWhileStatement</id>
<id>DoubleNegative</id>
<severity>style</severity>
<summary>Empty while statement.</summary>
<summary>Double negative is always positive.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[ for \( ; [^;]+ ; \)]]></pattern>
<pattern><![CDATA[~ ~]]></pattern>
<message>
<id>ForLoopShouldBeWhileLoop</id>
<id>DoubleNegative</id>
<severity>style</severity>
<summary>For loop should be written as a while loop.</summary>
<summary>Double negative is always positive.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern>goto</pattern>
<pattern><![CDATA[! \( !]]></pattern>
<message>
<id>GotoStatement</id>
<id>DoubleNegative</id>
<severity>style</severity>
<summary>Goto considered harmful.</summary>
<summary>Double negative is always positive.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[~ \( ~]]></pattern>
<message>
<id>DoubleNegative</id>
<severity>style</severity>
<summary>Double negative is always positive.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ != \w+ \) ({[^{}]*(?1)*[^{}]*}) else { (?!if)]]></pattern>
<pattern><![CDATA[if \( \w+ != \w+ \) ({[^{}]*(?-1)*[^{}]*}) else { (?!if)]]></pattern>
<message>
<id>InvertedLogic</id>
<severity>style</severity>
......@@ -208,7 +244,7 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( ! \w+ \) ({[^{}]*(?1)*[^{}]*}) else { (?!if)]]></pattern>
<pattern><![CDATA[if \( ! \w+ \) ({[^{}]*(?-1)*[^{}]*}) else { (?!if)]]></pattern>
<message>
<id>InvertedLogic</id>
<severity>style</severity>
......@@ -235,34 +271,43 @@
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
<pattern><![CDATA[catch (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message>
<id>RedundantConditionalOperator</id>
<id>EmptyCatchStatement</id>
<severity>style</severity>
<summary>Conditional operator is redundant.</summary>
<summary>An empty catch statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { return (true|false) ; } else { return (true|false) ; }]]></pattern>
<pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { assert (\([^()]*(?-1)*[^()]*\)) ; }]]></pattern>
<message>
<id>RedundantIfStatement</id>
<id>ConditionalAssert</id>
<severity>style</severity>
<summary>The if statement is redundant.</summary>
<summary>The if condition should be included in assert.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { [^{}]* (return|throw|break|continue) [^;]* ; } else {]]></pattern>
<pattern><![CDATA[if \( (\w) . empty \( \) \) { for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* \w : \1 \) ({[^{}]*(?-1)*[^{}]*}) }]]></pattern>
<message>
<id>UnnecessaryElseStatement</id>
<id>UnnecessaryEmptyCondition</id>
<severity>style</severity>
<summary>Else statement is not necessary.</summary>
<summary>Unnecessary check for empty before for range loop.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ \[ \1 \] ; }]]></pattern>
<pattern><![CDATA[if \( ! (\w) . empty \( \) \) { for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* \w : \1 \) ({[^{}]*(?-1)*[^{}]*}) }]]></pattern>
<message>
<id>UnnecessaryEmptyCondition</id>
<severity>style</severity>
<summary>Unnecessary check for empty before for range loop.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ \[ \1 \] ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
......@@ -270,8 +315,8 @@
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ ; }]]></pattern>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
......@@ -279,8 +324,8 @@
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \) ; }]]></pattern>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
......@@ -288,8 +333,8 @@
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \w+ \[ \1 \] \) ; }]]></pattern>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \w+ \[ \1 \] \) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
......@@ -297,11 +342,65 @@
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \w+ \[ \1 \] , \w+ \[ \1 \] \) ; }]]></pattern>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \w+ \[ \1 \] , \w+ \[ \1 \] \) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::transform instead.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { (?:(?<idx1>\w+) \+\+|\+\+ (?<idx2>\w+)) ; if (\([^()]*(?-1)*[^()]*\)) { \w+ = \g{idx1}|\g{idx2} ; (?:break ; )?(?:return [^;]*; )?} }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { if (\([^()]*(?-1)*[^()]*\)) { \w+ = (?<idx>\w) ; (?:break ; )?(?:return [^;]*; )?} (?:(\g{idx}) \+\+|\+\+ (\g{idx})) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { (?:(?<idx1>\w+) \+\+|\+\+ (?<idx2>\w+)) ; if (\([^()]*(?-1)*[^()]*\)) { return \g{idx1}|\g{idx2} ; } }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { if (\([^()]*(?-1)*[^()]*\)) { return (?<idx>\w+) ; } (?:(\g{idx}) \+\+|\+\+ (\g{idx})) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[do { } while \(]]></pattern>
<message>
<id>EmptyDoWhileStatement</id>
<severity>style</severity>
<summary>Empty do-while.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern>goto</pattern>
<message>
<id>GotoStatement</id>
<severity>style</severity>
<summary>Goto considered harmful.</summary>
</message>
</rule>
......@@ -9,19 +9,60 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{
try
{
compute_shape(op, args);
shape new_shape = ins->get_operator().compute_shape(inputs);
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
return true;
}
// if no changes for the shape, the contiguous can also be removed
if(new_shape == ins->get_shape())
{
return true;
}
auto outputs = ins->outputs();
// If the current instruction has no output, it means it is the last
// instruction and generates a non-standard output shape, and the last
// output shape is different from the case with the contiguous operator
if(outputs.empty())
{
return false;
}
for(auto output : outputs)
{
auto args = output->inputs();
std::vector<shape> input_shapes(args.size());
std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) {
return (arg == ins) ? new_shape : arg->get_shape();
});
if(!try_compute_shape(output, input_shapes))
{
return false;
}
}
}
catch(...)
{
return false;
}
return true;
}
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
auto inputs = to_shapes(args);
return try_compute_shape(ins, inputs);
}
void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
......@@ -37,7 +78,7 @@ void eliminate_contiguous::apply(program& p) const
auto new_args = args;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins->get_operator(), new_args))
if(try_compute_shape(ins, new_args))
{
instruction::replace_argument(ins, arg, prev);
}
......
......@@ -103,6 +103,13 @@ struct check_shapes
return *this;
}
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this;
}
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
......
......@@ -17,7 +17,7 @@ constexpr T normalize(unsigned long z)
return T(0);
const auto max = 32;
const double range = max / 2; // NOLINT
double result = (z % max) / range;
double result = double(z % max) / range;
result -= 1;
return T(result);
}
......
......@@ -24,7 +24,7 @@ struct instruction
instruction(literal l);
void replace(const shape& r);
void replace(operation o);
void recompute_shape();
......@@ -72,7 +72,9 @@ struct instruction
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const;
bool can_eval() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx);
......@@ -88,7 +90,8 @@ struct instruction
// internal
void replace_argument(instruction_ref old, instruction_ref new_ins);
private:
void replace(const shape& r);
operation op;
shape result;
std::vector<instruction_ref> output;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_INT_DIVIDE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT_DIVIDE_HPP
#include <migraphx/config.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class R, class T, class U>
R floor_divide(T x, U y)
{
return R(std::floor(double(x) / double(y)));
}
template <class R, class T, class U>
R ceil_divide(T x, U y)
{
return R(std::ceil(double(x) / double(y)));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -39,6 +39,11 @@ struct undefined
struct unknown
{
std::string op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
......
......@@ -36,7 +36,7 @@ struct as_shape
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -13,10 +13,16 @@ struct binary : op_name<Derived>
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
const auto& s = inputs.front();
if(s.scalar() and s.elements() == 1)
return {s.type()};
return {s.type(), s.lens()};
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
......
......@@ -63,7 +63,7 @@ struct broadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <limits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct clip : unary<clip>
{
float max_val = std::numeric_limits<float>::max();
float min_val = std::numeric_limits<float>::min();
clip() {}
clip(float max, float min) : max_val(max), min_val(min) {}
auto apply() const
{
auto max = max_val;
auto min = min_val;
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_val, "max"), f(self.min_val, "min"));
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -19,6 +19,13 @@ namespace op {
struct concat
{
std::size_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const
......
......@@ -46,7 +46,7 @@ struct flatten
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -30,7 +30,7 @@ struct gather
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
......
......@@ -24,7 +24,7 @@ struct identity
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment