"vscode:/vscode.git/clone" did not exist on "624b21e742f2bfc493b30ca17e7c86ca9255e1e6"
Commit 3d3ed155 authored by Scott Thornton's avatar Scott Thornton
Browse files

Refactored unary and binary operators, added reshape operator and test, refactored tests

parent 88bdd75a
...@@ -229,9 +229,6 @@ struct reshape ...@@ -229,9 +229,6 @@ struct reshape
struct gemm struct gemm
{ {
std::string name() const { return "gemm";} std::string name() const { return "gemm";}
std::size_t lda = 1;
std::size_t ldb = 1;
std::size_t ldc = 1;
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type(); check_shapes{inputs}.has(2).same_type();
...@@ -254,9 +251,8 @@ struct gemm ...@@ -254,9 +251,8 @@ struct gemm
} }
}; };
struct identity struct unary
{ {
std::string name() const {return "identity"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
...@@ -265,136 +261,69 @@ struct identity ...@@ -265,136 +261,69 @@ struct identity
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct abs struct identity : unary
{
std::string name() const {return "identity"; }
};
struct abs : unary
{ {
std::string name() const {return "abs"; } std::string name() const {return "abs"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct exp struct exp : unary
{ {
std::string name() const { return "exp"; } std::string name() const { return "exp"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct sin struct sin : unary
{ {
std::string name() const {return "sin"; } std::string name() const {return "sin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct cos struct cos : unary
{ {
std::string name() const {return "cos"; } std::string name() const {return "cos"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct tan struct tan : unary
{ {
std::string name() const {return "tan"; } std::string name() const {return "tan"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct asin struct asin : unary
{ {
std::string name() const {return "asin"; } std::string name() const {return "asin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct acos struct acos : unary
{ {
std::string name() const {return "acos"; } std::string name() const {return "acos"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct atan struct atan : unary
{ {
std::string name() const {return "atan"; } std::string name() const {return "atan"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct softmax struct softmax : unary
{ {
std::string name() const {return "softmax"; } std::string name() const {return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct tanh struct tanh : unary
{ {
std::string name() const {return "tanh"; } std::string name() const {return "tanh"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct sigmoid struct sigmoid : unary
{ {
std::string name() const {return "sigmoid"; } std::string name() const {return "sigmoid"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct neg struct neg : unary
{ {
std::string name() const {return "neg"; } std::string name() const {return "neg"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct flatten struct flatten
...@@ -402,9 +331,8 @@ struct flatten ...@@ -402,9 +331,8 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
}; };
struct add struct binary
{ {
std::string name() const { return "add"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations // TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
...@@ -413,37 +341,24 @@ struct add ...@@ -413,37 +341,24 @@ struct add
} }
}; };
struct sub struct add : binary
{
std::string name() const { return "add"; }
};
struct sub : binary
{ {
std::string name() const { return "sub"; } std::string name() const { return "sub"; }
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
}; };
struct mul struct mul : binary
{ {
std::string name() const { return "mul"; } std::string name() const { return "mul"; }
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
}; };
struct div struct div : binary
{ {
std::string name() const { return "div"; } std::string name() const { return "div"; }
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
}; };
struct outline struct outline
......
...@@ -51,6 +51,21 @@ struct cpu_convolution ...@@ -51,6 +51,21 @@ struct cpu_convolution
} }
}; };
struct cpu_reshape
{
reshape op;
std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(inputs);
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_gemm struct cpu_gemm
{ {
gemm op; gemm op;
...@@ -277,6 +292,14 @@ struct cpu_apply ...@@ -277,6 +292,14 @@ struct cpu_apply
{ {
apply_convolution(it); apply_convolution(it);
} }
else if(it->op.name() == "gemm")
{
apply_gemm(it);
}
else if(it->op.name() == "reshape")
{
apply_reshape(it);
}
else if(it->op.name() == "activation") else if(it->op.name() == "activation")
{ {
apply_activation(it); apply_activation(it);
...@@ -317,10 +340,6 @@ struct cpu_apply ...@@ -317,10 +340,6 @@ struct cpu_apply
{ {
apply_tan(it); apply_tan(it);
} }
else if(it->op.name() == "gemm")
{
apply_gemm(it);
}
} }
} }
...@@ -336,6 +355,12 @@ struct cpu_apply ...@@ -336,6 +355,12 @@ struct cpu_apply
prog->replace_instruction(ins, cpu_gemm{op}, ins->arguments); prog->replace_instruction(ins, cpu_gemm{op}, ins->arguments);
} }
void apply_reshape(instruction_ref ins)
{
auto&& op = any_cast<reshape>(ins->op);
prog->replace_instruction(ins, cpu_reshape{op}, ins->arguments);
}
void apply_activation(instruction_ref ins) void apply_activation(instruction_ref ins)
{ {
auto&& op = any_cast<activation>(ins->op); auto&& op = any_cast<activation>(ins->op);
......
#include <cassert>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <rtg/literal.hpp> #include <rtg/literal.hpp>
#include <rtg/operators.hpp> #include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp> #include <rtg/cpu/cpu_target.hpp>
#include "test.hpp"
void exp_test() { void exp_test() {
rtg::program p; rtg::program p;
...@@ -13,14 +13,111 @@ void exp_test() { ...@@ -13,14 +13,111 @@ void exp_test() {
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
memcpy(results_vector.data(), result.data(), 3*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.36787944f,1.f,2.71828183f}; std::vector<float> gold = {0.36787944f,1.f,2.71828183f};
float tol = 1e-8; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void sin_test() {
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l = p.add_literal(rtg::literal{s, {-1,0,1}});
p.add_instruction(rtg::sin{}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.84147098f,0.f,0.84147098f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void cos_test() {
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l = p.add_literal(rtg::literal{s, {-1,0,1}});
p.add_instruction(rtg::cos{}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.54030231f,1.f,0.54030231f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void tan_test() {
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l = p.add_literal(rtg::literal{s, {-1,0,1}});
p.add_instruction(rtg::tan{}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.55740772f,0.0f,1.55740772f};
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-gold[i]) < tol); EXPECT(std::abs(results_vector[i]-gold[i]) < tol);
}
}
void reshape_test() {
rtg::shape a_shape{rtg::shape::float_type, {24,1,1,1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> new_shape = {8,3,1,1};
p.add_instruction(rtg::reshape{new_shape}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-8;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-data[i]) < tol);
}
}
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> new_shape = {1,3,4,2};
p.add_instruction(rtg::reshape{new_shape}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-8;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-data[i]) < tol);
}
}
{
rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> new_shape = {1,3,4,2};
p.add_instruction(rtg::reshape{new_shape}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-8;
for (int i = 0; i < results_vector.size(); i++) {
EXPECT(std::abs(results_vector[i]-data[i]) < tol);
}
} }
} }
//std::cout << std::abs(results_vector[i]-gold[i]) << std::endl;
void gemm_test() { void gemm_test() {
rtg::program p; rtg::program p;
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885 , std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885 ,
...@@ -44,10 +141,10 @@ void gemm_test() { ...@@ -44,10 +141,10 @@ void gemm_test() {
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
memcpy(results_vector.data(), result.data(), 12*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-c[i]) < tol); EXPECT(std::abs(results_vector[i]-c[i]) < tol);
} }
} }
...@@ -125,10 +222,10 @@ void softmax_test() { ...@@ -125,10 +222,10 @@ void softmax_test() {
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
memcpy(results_vector.data(), result.data(), 120*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
...@@ -190,10 +287,10 @@ void conv2d_test() { ...@@ -190,10 +287,10 @@ void conv2d_test() {
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
memcpy(results_vector.data(), result.data(), 16*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
...@@ -257,10 +354,10 @@ void conv2d_padding_test() { ...@@ -257,10 +354,10 @@ void conv2d_padding_test() {
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
memcpy(results_vector.data(), result.data(), 64*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
...@@ -315,16 +412,21 @@ void conv2d_padding_stride_test() { ...@@ -315,16 +412,21 @@ void conv2d_padding_stride_test() {
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
memcpy(results_vector.data(), result.data(), 16*sizeof(float)); result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) { for (int i = 0; i < results_vector.size(); i++) {
assert(std::abs(results_vector[i]-s[i]) < tol); EXPECT(std::abs(results_vector[i]-s[i]) < tol);
} }
} }
int main() int main()
{ {
exp_test(); exp_test();
sin_test();
cos_test();
tan_test();
gemm_test(); gemm_test();
reshape_test();
softmax_test(); softmax_test();
conv2d_test(); conv2d_test();
conv2d_padding_test(); conv2d_padding_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment