Commit 3d3ed155 authored by Scott Thornton

Refactored unary and binary operators, added reshape operator and test, refactored tests

parent 88bdd75a
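
The refactor in this commit collapses a dozen near-identical operator structs into two base structs, `unary` and `binary`, that own the shared `compute_shape` logic; each concrete operator shrinks to a one-line `name()` override. A minimal, self-contained sketch of that pattern (stand-in types, not the rtg sources; the plain size check stands in for `check_shapes`):

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for rtg::shape: just the dimensions.
struct shape
{
    std::vector<std::size_t> dims;
};

// The behaviour every elementwise unary op repeated: one input,
// output shape identical to the input shape.
struct unary
{
    shape compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.size() != 1)
            throw std::runtime_error("unary: expected exactly 1 input");
        return inputs.at(0);
    }
};

// A concrete op now only contributes its name.
struct sin_op : unary
{
    std::string name() const { return "sin"; }
};

int main()
{
    sin_op op;
    shape out = op.compute_shape({shape{{3}}});
    std::cout << op.name() << " keeps rank " << out.dims.size() << "\n";
}

The diff below applies exactly this deduplication to abs, exp, sin, and the rest, and the analogous `binary` base to add, sub, mul, and div.
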
@@ -229,9 +229,6 @@ struct reshape
 struct gemm
 {
     std::string name() const { return "gemm"; }
-    std::size_t lda = 1;
-    std::size_t ldb = 1;
-    std::size_t ldc = 1;
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(2).same_type();
@@ -254,9 +251,8 @@ struct gemm
     }
 };
 
-struct identity
+struct unary
 {
-    std::string name() const { return "identity"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1);
@@ -265,136 +261,69 @@ struct identity
     argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct abs
+struct identity : unary
+{
+    std::string name() const { return "identity"; }
+};
+
+struct abs : unary
 {
     std::string name() const { return "abs"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct exp
+struct exp : unary
 {
     std::string name() const { return "exp"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct sin
+struct sin : unary
 {
     std::string name() const { return "sin"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct cos
+struct cos : unary
 {
     std::string name() const { return "cos"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct tan
+struct tan : unary
 {
     std::string name() const { return "tan"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct asin
+struct asin : unary
 {
     std::string name() const { return "asin"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct acos
+struct acos : unary
 {
     std::string name() const { return "acos"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct atan
+struct atan : unary
 {
     std::string name() const { return "atan"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct softmax
+struct softmax : unary
 {
     std::string name() const { return "softmax"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1).only_dims(4);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct tanh
+struct tanh : unary
 {
     std::string name() const { return "tanh"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct sigmoid
+struct sigmoid : unary
 {
     std::string name() const { return "sigmoid"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
-struct neg
+struct neg : unary
 {
     std::string name() const { return "neg"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs}.has(1);
-        return inputs.at(0);
-    }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 
 struct flatten
@@ -402,9 +331,8 @@ struct flatten
     std::string name() const { return "flatten"; }
 };
 
-struct add
+struct binary
 {
-    std::string name() const { return "add"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
         // TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
@@ -413,37 +341,24 @@ struct add
     }
 };
 
-struct sub
+struct add : binary
+{
+    std::string name() const { return "add"; }
+};
+
+struct sub : binary
 {
     std::string name() const { return "sub"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        // TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
-        check_shapes{inputs}.has(2).same_type().same_dims();
-        return inputs.at(0);
-    }
 };
 
-struct mul
+struct mul : binary
 {
     std::string name() const { return "mul"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        // TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
-        check_shapes{inputs}.has(2).same_type().same_dims();
-        return inputs.at(0);
-    }
 };
 
-struct div
+struct div : binary
 {
     std::string name() const { return "div"; }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        // TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
-        check_shapes{inputs}.has(2).same_type().same_dims();
-        return inputs.at(0);
-    }
 };
 
 struct outline
...
@@ -51,6 +51,21 @@ struct cpu_convolution
     }
 };
 
+struct cpu_reshape
+{
+    reshape op;
+
+    std::string name() const { return "cpu::reshape"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        return op.compute_shape(inputs);
+    }
+    argument compute(shape output_shape, std::vector<argument> args) const
+    {
+        return {output_shape, std::move(args.front().data)};
+    }
+};
+
 struct cpu_gemm
 {
     gemm op;
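
Note that `cpu_reshape::compute` moves the input argument's buffer and returns it paired with the output shape; no elementwise copy happens. That is sound because a reshape only reinterprets the dimensions of a packed row-major buffer with the same total element count. A small sketch of the invariant, with stand-in types rather than the rtg `shape`/`argument`:

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Total element count of a dims vector.
std::size_t elements(const std::vector<std::size_t>& dims)
{
    return std::accumulate(dims.begin(), dims.end(), std::size_t{1},
                           std::multiplies<std::size_t>{});
}

int main()
{
    std::vector<std::size_t> in_dims = {24, 1, 1, 1};
    std::vector<std::size_t> out_dims = {1, 3, 4, 2};

    // Same element count and a packed row-major layout on both sides,
    // so the bytes can be handed over unchanged.
    assert(elements(in_dims) == elements(out_dims));

    std::vector<float> data(24);
    std::iota(data.begin(), data.end(), -3.0f);
    // "Reshaping" is only a reinterpretation: data[i] is unchanged.
    assert(data.front() == -3.0f && data.back() == 20.0f);
}
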
@@ -277,6 +292,14 @@ struct cpu_apply
             {
                 apply_convolution(it);
             }
+            else if(it->op.name() == "gemm")
+            {
+                apply_gemm(it);
+            }
+            else if(it->op.name() == "reshape")
+            {
+                apply_reshape(it);
+            }
             else if(it->op.name() == "activation")
             {
                 apply_activation(it);
@@ -317,10 +340,6 @@ struct cpu_apply
             {
                 apply_tan(it);
             }
-            else if(it->op.name() == "gemm")
-            {
-                apply_gemm(it);
-            }
         }
     }
@@ -336,6 +355,12 @@ struct cpu_apply
         prog->replace_instruction(ins, cpu_gemm{op}, ins->arguments);
     }
 
+    void apply_reshape(instruction_ref ins)
+    {
+        auto&& op = any_cast<reshape>(ins->op);
+        prog->replace_instruction(ins, cpu_reshape{op}, ins->arguments);
+    }
+
     void apply_activation(instruction_ref ins)
     {
         auto&& op = any_cast<activation>(ins->op);
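
The new `apply_reshape` follows the same lowering recipe as `apply_gemm`: match the instruction's op by name, then swap the generic op for its cpu kernel via `replace_instruction`. A compact sketch of that name-based dispatch, using a map of callbacks in place of the if/else chain (stand-in code; `cpu_activation` is a hypothetical label here, not a confirmed rtg type):

#include <functional>
#include <iostream>
#include <map>
#include <string>

int main()
{
    // Each entry stands in for one apply_* member of cpu_apply.
    std::map<std::string, std::function<void()>> lowering = {
        {"gemm", [] { std::cout << "replace with cpu_gemm\n"; }},
        {"reshape", [] { std::cout << "replace with cpu_reshape\n"; }},
        {"activation", [] { std::cout << "replace with cpu_activation\n"; }},
    };

    // Walking the program: look each op up by name and lower it.
    for(const std::string name : {"reshape", "gemm"})
    {
        auto it = lowering.find(name);
        if(it != lowering.end())
            it->second();
    }
}
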
...
 #include <cassert>
+#include <cmath>
 #include <iostream>
+#include <numeric>
 #include <vector>
 #include <rtg/literal.hpp>
 #include <rtg/operators.hpp>
 #include <rtg/cpu/cpu_target.hpp>
 #include "test.hpp"
 
 void exp_test() {
     rtg::program p;
@@ -13,14 +13,111 @@ void exp_test() {
     p.compile(rtg::cpu::cpu_target{});
     auto result = p.eval({});
     std::vector<float> results_vector(3);
-    memcpy(results_vector.data(), result.data(), 3*sizeof(float));
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
     float tol = 1e-6;
     for (int i = 0; i < results_vector.size(); i++) {
         EXPECT(std::abs(results_vector[i] - gold[i]) < tol);
     }
 }
 
+void sin_test() {
+    rtg::program p;
+    rtg::shape s{rtg::shape::float_type, {3}};
+    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
+    p.add_instruction(rtg::sin{}, l);
+    p.compile(rtg::cpu::cpu_target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector(3);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
+    float tol = 1e-6;
+    for (int i = 0; i < results_vector.size(); i++) {
+        EXPECT(std::abs(results_vector[i] - gold[i]) < tol);
+    }
+}
+
+void cos_test() {
+    rtg::program p;
+    rtg::shape s{rtg::shape::float_type, {3}};
+    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
+    p.add_instruction(rtg::cos{}, l);
+    p.compile(rtg::cpu::cpu_target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector(3);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
+    float tol = 1e-6;
+    for (int i = 0; i < results_vector.size(); i++) {
+        EXPECT(std::abs(results_vector[i] - gold[i]) < tol);
+    }
+}
+
+void tan_test() {
+    rtg::program p;
+    rtg::shape s{rtg::shape::float_type, {3}};
+    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
+    p.add_instruction(rtg::tan{}, l);
+    p.compile(rtg::cpu::cpu_target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector(3);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
+    float tol = 1e-6;
+    for (int i = 0; i < results_vector.size(); i++) {
+        EXPECT(std::abs(results_vector[i] - gold[i]) < tol);
+    }
+}
+
+void reshape_test() {
+    rtg::shape a_shape{rtg::shape::float_type, {24, 1, 1, 1}};
+    std::vector<float> data(24);
+    std::iota(data.begin(), data.end(), -3);
+    {
+        rtg::program p;
+        auto l = p.add_literal(rtg::literal{a_shape, data});
+        std::vector<int64_t> new_shape = {8, 3, 1, 1};
+        p.add_instruction(rtg::reshape{new_shape}, l);
+        p.compile(rtg::cpu::cpu_target{});
+        auto result = p.eval({});
+        std::vector<float> results_vector(24);
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        float tol = 1e-8;
+        for (int i = 0; i < results_vector.size(); i++) {
+            EXPECT(std::abs(results_vector[i] - data[i]) < tol);
+        }
+    }
+    {
+        rtg::program p;
+        auto l = p.add_literal(rtg::literal{a_shape, data});
+        std::vector<int64_t> new_shape = {1, 3, 4, 2};
+        p.add_instruction(rtg::reshape{new_shape}, l);
+        p.compile(rtg::cpu::cpu_target{});
+        auto result = p.eval({});
+        std::vector<float> results_vector(24);
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        float tol = 1e-8;
+        for (int i = 0; i < results_vector.size(); i++) {
+            EXPECT(std::abs(results_vector[i] - data[i]) < tol);
+        }
+    }
+    {
+        rtg::program p;
+        auto l = p.add_literal(rtg::literal{a_shape, data});
+        std::vector<int64_t> new_shape = {1, 3, 4, 2};
+        p.add_instruction(rtg::reshape{new_shape}, l);
+        p.compile(rtg::cpu::cpu_target{});
+        auto result = p.eval({});
+        std::vector<float> results_vector(24);
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        float tol = 1e-8;
+        for (int i = 0; i < results_vector.size(); i++) {
+            EXPECT(std::abs(results_vector[i] - data[i]) < tol);
+        }
+    }
+}
 
 void gemm_test() {
     rtg::program p;
     std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885 ,
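
Throughout these tests the raw `memcpy` readback is replaced with `result.visit(...)`. The point of the change: `visit` hands the lambda a typed view of the output, so the copy adapts to whatever element type the result holds instead of assuming a float layout. A sketch of the idiom with a stand-in buffer type (not the rtg `argument` API):

#include <iostream>
#include <variant>
#include <vector>

struct result_buffer
{
    std::variant<std::vector<float>, std::vector<int>> data;

    template <class F>
    void visit(F f) const
    {
        std::visit(f, data); // dispatch on the runtime element type
    }
};

int main()
{
    result_buffer result{std::vector<int>{1, 2, 3}};
    std::vector<float> results_vector;
    result.visit([&](const auto& output) {
        results_vector.assign(output.begin(), output.end());
    });
    std::cout << results_vector[2] << "\n"; // prints 3
}
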
@@ -44,10 +141,10 @@ void gemm_test() {
     p.compile(rtg::cpu::cpu_target{});
     auto result = p.eval({});
     std::vector<float> results_vector(12);
-    memcpy(results_vector.data(), result.data(), 12*sizeof(float));
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     float tol = 1e-6;
     for (int i = 0; i < results_vector.size(); i++) {
-        assert(std::abs(results_vector[i]-c[i]) < tol);
+        EXPECT(std::abs(results_vector[i] - c[i]) < tol);
     }
 }
@@ -125,10 +222,10 @@ void softmax_test() {
     p.compile(rtg::cpu::cpu_target{});
     auto result = p.eval({});
     std::vector<float> results_vector(120);
-    memcpy(results_vector.data(), result.data(), 120*sizeof(float));
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     float tol = 1e-6;
     for (int i = 0; i < results_vector.size(); i++) {
-        assert(std::abs(results_vector[i]-s[i]) < tol);
+        EXPECT(std::abs(results_vector[i] - s[i]) < tol);
     }
 }
@@ -190,10 +287,10 @@ void conv2d_test() {
     auto result = p.eval({});
     std::vector<float> results_vector(16);
-    memcpy(results_vector.data(), result.data(), 16*sizeof(float));
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     float tol = 1e-6;
     for (int i = 0; i < results_vector.size(); i++) {
-        assert(std::abs(results_vector[i]-s[i]) < tol);
+        EXPECT(std::abs(results_vector[i] - s[i]) < tol);
     }
 }
@@ -257,10 +354,10 @@ void conv2d_padding_test() {
     auto result = p.eval({});
     std::vector<float> results_vector(64);
-    memcpy(results_vector.data(), result.data(), 64*sizeof(float));
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     float tol = 1e-6;
     for (int i = 0; i < results_vector.size(); i++) {
-        assert(std::abs(results_vector[i]-s[i]) < tol);
+        EXPECT(std::abs(results_vector[i] - s[i]) < tol);
     }
 }
@@ -315,16 +412,21 @@ void conv2d_padding_stride_test() {
     auto result = p.eval({});
     std::vector<float> results_vector(16);
-    memcpy(results_vector.data(), result.data(), 16*sizeof(float));
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     float tol = 1e-6;
     for (int i = 0; i < results_vector.size(); i++) {
-        assert(std::abs(results_vector[i]-s[i]) < tol);
+        EXPECT(std::abs(results_vector[i] - s[i]) < tol);
     }
 }
 
 int main()
 {
     exp_test();
+    sin_test();
+    cos_test();
+    tan_test();
     gemm_test();
+    reshape_test();
     softmax_test();
     conv2d_test();
     conv2d_padding_test();
...