Commit 3d3ed155 authored by Scott Thornton's avatar Scott Thornton
Browse files

Refactored unary and binary operators, added reshape operator and test, refactored tests

parent 88bdd75a
...@@ -229,9 +229,6 @@ struct reshape ...@@ -229,9 +229,6 @@ struct reshape
struct gemm struct gemm
{ {
std::string name() const { return "gemm";} std::string name() const { return "gemm";}
std::size_t lda = 1;
std::size_t ldb = 1;
std::size_t ldc = 1;
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type(); check_shapes{inputs}.has(2).same_type();
...@@ -254,9 +251,8 @@ struct gemm ...@@ -254,9 +251,8 @@ struct gemm
} }
}; };
struct identity struct unary
{ {
std::string name() const {return "identity"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
...@@ -265,136 +261,69 @@ struct identity ...@@ -265,136 +261,69 @@ struct identity
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct abs struct identity : unary
{
std::string name() const {return "identity"; }
};
struct abs : unary
{ {
std::string name() const {return "abs"; } std::string name() const {return "abs"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct exp struct exp : unary
{ {
std::string name() const { return "exp"; } std::string name() const { return "exp"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct sin struct sin : unary
{ {
std::string name() const {return "sin"; } std::string name() const {return "sin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct cos struct cos : unary
{ {
std::string name() const {return "cos"; } std::string name() const {return "cos"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct tan struct tan : unary
{ {
std::string name() const {return "tan"; } std::string name() const {return "tan"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct asin struct asin : unary
{ {
std::string name() const {return "asin"; } std::string name() const {return "asin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct acos struct acos : unary
{ {
std::string name() const {return "acos"; } std::string name() const {return "acos"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct atan struct atan : unary
{ {
std::string name() const {return "atan"; } std::string name() const {return "atan"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct softmax struct softmax : unary
{ {
std::string name() const {return "softmax"; } std::string name() const {return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct tanh struct tanh : unary
{ {
std::string name() const {return "tanh"; } std::string name() const {return "tanh"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct sigmoid struct sigmoid : unary
{ {
std::string name() const {return "sigmoid"; } std::string name() const {return "sigmoid"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct neg struct neg : unary
{ {
std::string name() const {return "neg"; } std::string name() const {return "neg"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct flatten struct flatten
...@@ -402,9 +331,8 @@ struct flatten ...@@ -402,9 +331,8 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
}; };
struct add struct binary
{ {
std::string name() const { return "add"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations // TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
...@@ -413,37 +341,24 @@ struct add ...@@ -413,37 +341,24 @@ struct add
} }
}; };
struct sub struct add : binary
{
std::string name() const { return "add"; }
};
struct sub : binary
{ {
std::string name() const { return "sub"; } std::string name() const { return "sub"; }
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
}; };
struct mul struct mul : binary
{ {
std::string name() const { return "mul"; } std::string name() const { return "mul"; }
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
}; };
struct div struct div : binary
{ {
std::string name() const { return "div"; } std::string name() const { return "div"; }
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
}; };
struct outline struct outline
......
...@@ -51,6 +51,21 @@ struct cpu_convolution ...@@ -51,6 +51,21 @@ struct cpu_convolution
} }
}; };
struct cpu_reshape
{
reshape op;
std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(inputs);
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_gemm struct cpu_gemm
{ {
gemm op; gemm op;
...@@ -277,6 +292,14 @@ struct cpu_apply ...@@ -277,6 +292,14 @@ struct cpu_apply
{ {
apply_convolution(it); apply_convolution(it);
} }
else if(it->op.name() == "gemm")
{
apply_gemm(it);
}
else if(it->op.name() == "reshape")
{
apply_reshape(it);
}
else if(it->op.name() == "activation") else if(it->op.name() == "activation")
{ {
apply_activation(it); apply_activation(it);
...@@ -317,10 +340,6 @@ struct cpu_apply ...@@ -317,10 +340,6 @@ struct cpu_apply
{ {
apply_tan(it); apply_tan(it);
} }
else if(it->op.name() == "gemm")
{
apply_gemm(it);
}
} }
} }
...@@ -336,6 +355,12 @@ struct cpu_apply ...@@ -336,6 +355,12 @@ struct cpu_apply
prog->replace_instruction(ins, cpu_gemm{op}, ins->arguments); prog->replace_instruction(ins, cpu_gemm{op}, ins->arguments);
} }
void apply_reshape(instruction_ref ins)
{
auto&& op = any_cast<reshape>(ins->op);
prog->replace_instruction(ins, cpu_reshape{op}, ins->arguments);
}
void apply_activation(instruction_ref ins) void apply_activation(instruction_ref ins)
{ {
auto&& op = any_cast<activation>(ins->op); auto&& op = any_cast<activation>(ins->op);
......
#include <cassert>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <rtg/literal.hpp> #include <rtg/literal.hpp>
#include <rtg/operators.hpp> #include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp> #include <rtg/cpu/cpu_target.hpp>
#include "test.hpp"
void exp_test() { void exp_test() {
rtg::program p; rtg::program p;
...@@ -13,14 +13,111 @@ void exp_test() { ...@@ -13,14 +13,111 @@ void exp_test() {
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
memcpy(results_vector.data(), result.data(), 3*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.36787944f,1.f,2.71828183f}; std::vector<float> gold = {0.36787944f,1.f,2.71828183f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void sin_test() {
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l = p.add_literal(rtg::literal{s, {-1,0,1}});
p.add_instruction(rtg::sin{}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.84147098f,0.f,0.84147098f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void cos_test() {
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l = p.add_literal(rtg::literal{s, {-1,0,1}});
p.add_instruction(rtg::cos{}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.54030231f,1.f,0.54030231f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void tan_test() {
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l = p.add_literal(rtg::literal{s, {-1,0,1}});
p.add_instruction(rtg::tan{}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.55740772f,0.0f,1.55740772f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void reshape_test() {
rtg::shape a_shape{rtg::shape::float_type, {24,1,1,1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> new_shape = {8,3,1,1};
p.add_instruction(rtg::reshape{new_shape}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-8;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-data[i]) < tol);
}
}
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> new_shape = {1,3,4,2};
p.add_instruction(rtg::reshape{new_shape}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-8; float tol = 1e-8;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-gold[i]) < tol); EXPECT(std::abs(results_vector[i]-data[i]) < tol);
}
}
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> new_shape = {1,3,4,2};
p.add_instruction(rtg::reshape{new_shape}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-8;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-data[i]) < tol);
}
} }
} }
//std::cout << std::abs(results_vector[i]-gold[i]) << std::endl;
void gemm_test() { void gemm_test() {
rtg::program p; rtg::program p;
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885 , std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885 ,
...@@ -44,10 +141,10 @@ void gemm_test() { ...@@ -44,10 +141,10 @@ void gemm_test() {
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
memcpy(results_vector.data(), result.data(), 12*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-c[i]) < tol); EXPECT(std::abs(results_vector[i]-c[i]) < tol);
} }
} }
...@@ -125,10 +222,10 @@ void softmax_test() { ...@@ -125,10 +222,10 @@ void softmax_test() {
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
memcpy(results_vector.data(), result.data(), 120*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
...@@ -190,10 +287,10 @@ void conv2d_test() { ...@@ -190,10 +287,10 @@ void conv2d_test() {
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
memcpy(results_vector.data(), result.data(), 16*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
...@@ -257,10 +354,10 @@ void conv2d_padding_test() { ...@@ -257,10 +354,10 @@ void conv2d_padding_test() {
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
memcpy(results_vector.data(), result.data(), 64*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
...@@ -315,16 +412,21 @@ void conv2d_padding_stride_test() { ...@@ -315,16 +412,21 @@ void conv2d_padding_stride_test() {
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
memcpy(results_vector.data(), result.data(), 16*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
int main() int main()
{ {
exp_test(); exp_test();
sin_test();
cos_test();
tan_test();
gemm_test(); gemm_test();
reshape_test();
softmax_test(); softmax_test();
conv2d_test(); conv2d_test();
conv2d_padding_test(); conv2d_padding_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment