"...composable_kernel_rocm.git" did not exist on "9fb4914aa4ef3c6eb4d74c1713b7e6186b6f959c"
Commit dc2b0abf authored by Scott Thornton's avatar Scott Thornton
Browse files
parents ac230464 f2e18b73
...@@ -3,6 +3,12 @@ cmake_minimum_required(VERSION 3.5) ...@@ -3,6 +3,12 @@ cmake_minimum_required(VERSION 3.5)
project(rtglib) project(rtglib)
find_package(ROCM REQUIRED) find_package(ROCM REQUIRED)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.4")
message(FATAL_ERROR "RTGLib requires at least gcc 5.4")
endif()
endif()
add_compile_options(-std=c++14) add_compile_options(-std=c++14)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
...@@ -49,6 +55,7 @@ rocm_enable_clang_tidy( ...@@ -49,6 +55,7 @@ rocm_enable_clang_tidy(
-llvm-header-guard -llvm-header-guard
-llvm-include-order -llvm-include-order
-misc-macro-parentheses -misc-macro-parentheses
-modernize-use-auto
-modernize-pass-by-value -modernize-pass-by-value
-modernize-use-default-member-init -modernize-use-default-member-init
-modernize-use-transparent-functors -modernize-use-transparent-functors
......
FROM ubuntu:16.04 FROM ubuntu:xenial-20180417
ARG PREFIX=/usr/local ARG PREFIX=/usr/local
......
...@@ -320,6 +320,10 @@ function(add_doxygen_doc) ...@@ -320,6 +320,10 @@ function(add_doxygen_doc)
file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n") file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n")
if(NOT PARSE_STRIP_FROM_PATH)
set(PARSE_STRIP_FROM_PATH ${CMAKE_SOURCE_DIR})
endif()
foreach(ARG ${DOXYGEN_ARGS}) foreach(ARG ${DOXYGEN_ARGS})
if(PARSE_${ARG}) if(PARSE_${ARG})
string(REPLACE ";" " " ARG_VALUE "${PARSE_${ARG}}") string(REPLACE ";" " " ARG_VALUE "${PARSE_${ARG}}")
...@@ -342,7 +346,7 @@ function(add_doxygen_doc) ...@@ -342,7 +346,7 @@ function(add_doxygen_doc)
add_custom_target(doxygen add_custom_target(doxygen
${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE} ${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
COMMENT "Building documentation with doxygen" COMMENT "Building documentation with doxygen"
) )
if(PARSE_OUTPUT_DIRECTORY) if(PARSE_OUTPUT_DIRECTORY)
......
...@@ -16,6 +16,7 @@ add_doxygen_doc( ...@@ -16,6 +16,7 @@ add_doxygen_doc(
CALL_GRAPH YES CALL_GRAPH YES
CALLER_GRAPH YES CALLER_GRAPH YES
BUILTIN_STL_SUPPORT YES BUILTIN_STL_SUPPORT YES
PROJECT_NAME RTGLib
SORT_MEMBERS_CTORS_1ST YES SORT_MEMBERS_CTORS_1ST YES
SOURCE_BROWSER YES SOURCE_BROWSER YES
GENERATE_TREEVIEW YES GENERATE_TREEVIEW YES
...@@ -24,6 +25,7 @@ add_doxygen_doc( ...@@ -24,6 +25,7 @@ add_doxygen_doc(
REFERENCES_LINK_SOURCE YES REFERENCES_LINK_SOURCE YES
EXTRACT_ALL YES EXTRACT_ALL YES
ENUM_VALUES_PER_LINE 1 ENUM_VALUES_PER_LINE 1
FULL_PATH_NAMES YES
) )
# include(SphinxDoc) # include(SphinxDoc)
......
...@@ -12,7 +12,7 @@ struct literal ...@@ -12,7 +12,7 @@ struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
}; };
struct param struct param
...@@ -20,7 +20,7 @@ struct param ...@@ -20,7 +20,7 @@ struct param
std::string parameter; std::string parameter;
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name() << ":" << op.parameter; os << op.name() << ":" << op.parameter;
......
#ifndef RTG_GUARD_RTGLIB_DFOR_HPP
#define RTG_GUARD_RTGLIB_DFOR_HPP
namespace rtg {
// Multidimensional for loop
inline auto dfor()
{
return [](auto f) { f(); };
}
template <class T, class... Ts>
auto dfor(T x, Ts... xs)
{
return [=](auto f) {
for(T i = 0; i < x; i++)
{
dfor(xs...)([&](Ts... is) { f(i, is...); });
}
};
}
} // namespace rtg
#endif
#ifndef RTG_GUARD_RTG_MANAGE_PTR_HPP
#define RTG_GUARD_RTG_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
namespace rtg {
template <class F, F f>
struct manage_deleter
{
template <class T>
void operator()(T* x) const
{
if(x != nullptr)
{
f(x);
}
}
};
struct null_deleter
{
template <class T>
void operator()(T*) const
{
}
};
template <class T, class F, F f>
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
template <class T>
struct element_type
{
using type = typename T::element_type;
};
template <class T>
using remove_ptr = typename std::
conditional_t<std::is_pointer<T>{}, std::remove_pointer<T>, element_type<T>>::type;
template <class T>
using shared = std::shared_ptr<remove_ptr<T>>;
} // namespace rtg
#define RTG_MANAGE_PTR(T, F) rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
...@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const; * argument compute(shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
...@@ -95,10 +95,10 @@ struct operation ...@@ -95,10 +95,10 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
} }
argument compute(std::vector<argument> input) const argument compute(shape output, std::vector<argument> input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(input)); return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input));
} }
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
...@@ -114,10 +114,10 @@ struct operation ...@@ -114,10 +114,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0; virtual argument compute(shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -156,10 +156,10 @@ struct operation ...@@ -156,10 +156,10 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input)); return private_detail_te_value.compute_shape(std::move(input));
} }
argument compute(std::vector<argument> input) const override argument compute(shape output, std::vector<argument> input) const override
{ {
return private_detail_te_value.compute(std::move(input)); return private_detail_te_value.compute(std::move(output), std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
......
...@@ -8,9 +8,74 @@ ...@@ -8,9 +8,74 @@
namespace rtg { namespace rtg {
struct check_shapes
{
const std::vector<shape>* shapes;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
return *this;
}
template <class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
}
template <class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
struct not_computable struct not_computable
{ {
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct convolution struct convolution
...@@ -21,18 +86,11 @@ struct convolution ...@@ -21,18 +86,11 @@ struct convolution
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.size() != 2) check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);
RTG_THROW("Wrong number of arguments");
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
if(input.type() != weights.type()) auto t = input.type();
RTG_THROW("Type doesn't match");
if(input.lens().size() != weights.lens().size())
RTG_THROW("Dimensions don't match");
if(input.lens().size() != 4)
RTG_THROW("Only 4d convolution supported");
auto t = input.type();
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
...@@ -52,7 +110,7 @@ struct convolution ...@@ -52,7 +110,7 @@ struct convolution
}}; }};
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
...@@ -74,13 +132,10 @@ struct pooling ...@@ -74,13 +132,10 @@ struct pooling
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) check_shapes{inputs}.has(1).only_dims(4);
RTG_THROW("Wrong number of arguments");
const shape& input = inputs.at(0);
if(input.lens().size() != 4)
RTG_THROW("Only 4d pooling supported");
auto t = input.type(); const shape& input = inputs.at(0);
auto t = input.type();
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
...@@ -98,7 +153,7 @@ struct pooling ...@@ -98,7 +153,7 @@ struct pooling
}}; }};
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
...@@ -117,12 +172,11 @@ struct activation ...@@ -117,12 +172,11 @@ struct activation
std::string name() const { return "activation"; } std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) check_shapes{inputs}.has(1);
RTG_THROW("Wrong number of arguments");
return inputs.front(); return inputs.front();
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -153,7 +207,7 @@ struct reshape ...@@ -153,7 +207,7 @@ struct reshape
return {inputs.front().type(), rdims}; return {inputs.front().type(), rdims};
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
......
...@@ -82,7 +82,6 @@ struct raw_data ...@@ -82,7 +82,6 @@ struct raw_data
/** /**
* @brief Retrieves a single element of data * @brief Retrieves a single element of data
* @details [long description]
* *
* @param n The index to retrieve the data from * @param n The index to retrieve the data from
* @tparam T The type of data to be retrieved * @tparam T The type of data to be retrieved
...@@ -97,6 +96,38 @@ struct raw_data ...@@ -97,6 +96,38 @@ struct raw_data
} }
}; };
namespace detail {
template <class V, class... Ts>
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{
s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); });
}
} // namespace detail
/**
* @brief Visits every object together
* @details This will visit every object, but assumes each object is the same type. This can reduce
* the deeply nested visit calls. This will return a function that will take the visitor callback.
* So it will be called with `visit_all(xs...)([](auto... ys) {})` where `xs...` and `ys...` are the
* same number of parameters.
*
* @param x A raw data object
* @param xs Many raw data objects
* @return A function to be called with the visitor
*/
template <class T, class... Ts>
auto visit_all(T&& x, Ts&&... xs)
{
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
RTG_THROW("Types must be the same");
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
}
} // namespace rtg } // namespace rtg
#endif #endif
...@@ -33,7 +33,7 @@ struct tensor_view ...@@ -33,7 +33,7 @@ struct tensor_view
template <class... Ts> template <class... Ts>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
return m_data[m_shape.index({xs...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
T& operator[](std::size_t i) T& operator[](std::size_t i)
......
...@@ -23,7 +23,10 @@ struct unknown ...@@ -23,7 +23,10 @@ struct unknown
else else
return input.front(); return input.front();
} }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const
{
RTG_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x) friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{ {
os << x.name(); os << x.name();
......
...@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
ins.arguments.end(), ins.arguments.end(),
values.begin(), values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(values); result = ins.op.compute(ins.result, values);
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
......
add_library(rtg_cpu
cpu_target.cpp
)
rocm_clang_tidy_check(rtg_cpu)
target_link_libraries(rtg_cpu rtg)
target_include_directories(rtg_cpu PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
#include <rtg/cpu/cpu_target.hpp>
#include <rtg/instruction.hpp>
#include <rtg/dfor.hpp>
#include <rtg/operators.hpp>
namespace rtg {
namespace cpu {
struct cpu_convolution
{
convolution op;
std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_n = input.get_shape().lens()[0];
auto in_c = input.get_shape().lens()[1];
auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3];
auto wei_c = weights.get_shape().lens()[1];
auto wei_h = weights.get_shape().lens()[2];
auto wei_w = weights.get_shape().lens()[3];
dfor(in_n, in_c, in_h, in_w)(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x;
const int in_y = start_y + y;
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{
acc += input(o, k, in_x, in_y) * weights(w, k, x, y);
}
});
output(o, w, i, j) = acc;
});
});
return result;
}
};
struct relu
{
std::string name() const { return "cpu::relu"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [](auto x) {
return x > 0 ? x : 0;
});
});
});
return result;
}
};
struct cpu_apply
{
program* prog;
void apply()
{
for(auto it = prog->begin(); it != prog->end(); it++)
{
if(it->op.name() == "convolution")
{
apply_convolution(it);
}
else if(it->op.name() == "activation")
{
apply_activation(it);
}
}
}
void apply_convolution(instruction_ref ins)
{
auto&& op = any_cast<convolution>(ins->op);
prog->replace_instruction(ins, cpu_convolution{op}, ins->arguments);
}
void apply_activation(instruction_ref ins)
{
auto&& op = any_cast<activation>(ins->op);
if(op.mode == "relu")
prog->replace_instruction(ins, relu{}, ins->arguments);
}
};
std::string cpu_target::name() const { return "cpu"; }
void cpu_target::apply(program& p) const { cpu_apply{&p}.apply(); }
} // namespace cpu
} // namespace rtg
#ifndef RTG_GUARD_RTGLIB_CPU_TARGET_HPP
#define RTG_GUARD_RTGLIB_CPU_TARGET_HPP
#include <rtg/program.hpp>
namespace rtg {
namespace cpu {
struct cpu_target
{
std::string name() const;
void apply(program& p) const;
};
} // namespace cpu
} // namespace rtg
#endif
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
struct sum_op struct sum_op
{ {
std::string name() const { return "sum"; } std::string name() const { return "sum"; }
rtg::argument compute(std::vector<rtg::argument> args) const rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -37,7 +37,7 @@ struct sum_op ...@@ -37,7 +37,7 @@ struct sum_op
struct minus_op struct minus_op
{ {
std::string name() const { return "minus"; } std::string name() const { return "minus"; }
rtg::argument compute(std::vector<rtg::argument> args) const rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
......
...@@ -9,7 +9,10 @@ struct simple_operation ...@@ -9,7 +9,10 @@ struct simple_operation
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const
{
RTG_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op) friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{ {
os << "[" << op.name() << "]"; os << "[" << op.name() << "]";
...@@ -21,7 +24,10 @@ struct simple_operation_no_print ...@@ -21,7 +24,10 @@ struct simple_operation_no_print
{ {
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const
{
RTG_THROW("not computable");
}
}; };
void operation_copy_test() void operation_copy_test()
......
...@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', input='std::vector<argument>', const=True), virtual('compute', returns='argument', output='shape', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<') friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<')
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment