Commit ed6bf6d1 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added transpose and contiguous operators ... need tests

parent e13947ec
...@@ -235,38 +235,44 @@ struct transpose ...@@ -235,38 +235,44 @@ struct transpose
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_size = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
auto t = input.type(); auto t = input.type();
if (dims.size() != input_size.size()) { if (dims.size() != input_lens.size()) {
RTG_THROW("Permutation has wrong number of axes"); RTG_THROW("Permutation has wrong number of axes");
} }
// DEBUG
for (int i = 0; i < dims.size(); i++) {
std::cout << dims[i] << std::endl;
}
std::cout << std::endl;
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if (!std::is_permutation(axes.begin(), axes.end(), dims.begin())) { if (!std::is_permutation(axes.begin(), axes.end(), dims.begin())) {
RTG_THROW("Invalid permutation"); RTG_THROW("Invalid permutation");
} }
std::vector<size_t> output_size(input_size.size()); std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_size.size()); std::vector<size_t> output_strides(input_lens.size());
for (int i = 0; i < output_size.size(); i++) { for (int i = 0; i < output_lens.size(); i++) {
output_size[i] = input_size[dims[i]]; output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]]; output_strides[i] = input_strides[dims[i]];
//std::cout << input_size[i] << " " << output_size[i] << std::endl;
//std::cout << input_strides[i] << " " << output_strides[i] << std::endl;
} }
std::cout << std::endl; return {t, output_lens, output_strides};
return {t, output_size, output_strides};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct contiguous
{
std::vector<int64_t> dims;
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if (lens.size() < 2) {
RTG_THROW("Number of dimensions should exceed 1");
}
return {t, lens};
}
};
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
......
...@@ -61,30 +61,81 @@ struct cpu_transpose ...@@ -61,30 +61,81 @@ struct cpu_transpose
std::string name() const { return "cpu::transpose"; } std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_contiguous
{
contiguous op;
std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(inputs);
}
argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
auto input_shape = args[0].get_shape();
auto ndim = output_shape.lens().size();
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
value_type* ptr = static_cast<value_type*>(output.data()); value_type* ptr = static_cast<value_type*>(output.data());
auto nb = output_shape.lens()[0]; if (ndim == 2) {
auto nc = output_shape.lens()[1]; dfor(input_shape.lens()[0],
auto nh = output_shape.lens()[2]; input_shape.lens()[1])(
auto nw = output_shape.lens()[3]; [&](std::size_t i0, std::size_t i1) {
for (int kk = 0; kk < 4; kk++) { *ptr++ = input(i0,i1);
std::cout << "cpu_transpose: " << output_shape.lens()[kk] << " " << output_shape.strides()[kk] << std::endl; });
} }
else if (ndim == 3) {
for (int b = 0; b < nb; b++) { dfor(input_shape.lens()[0],
for (int c = 0; c < nc; c++) { input_shape.lens()[1],
for (int i = 0; i < nh; i++) { input_shape.lens()[2])(
for (int j = 0; j < nw; j++) { [&](std::size_t i0, std::size_t i1, std::size_t i2) {
*ptr++ = input(b,c,i,j); *ptr++ = input(i0,i1,i2);
std::cout << input(b,c,i,j) << " "; });
} }
else if (ndim == 4) {
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
*ptr++ = input(i0,i1,i2,i3);
});
} }
else if (ndim == 5) {
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3],
input_shape.lens()[4])(
[&](std::size_t i0,
std::size_t i1,
std::size_t i2,
std::size_t i3,
std::size_t i4) {
*ptr++ = input(i0,i1,i2,i3,i4);
});
} }
else if (ndim == 6) {
dfor(input_shape.lens()[0],
input_shape.lens()[1],
input_shape.lens()[2],
input_shape.lens()[3],
input_shape.lens()[4],
input_shape.lens()[5])(
[&](std::size_t i0,
std::size_t i1,
std::size_t i2,
std::size_t i3,
std::size_t i4,
std::size_t i5) {
*ptr++ = input(i0,i1,i2,i3,i4,i5);
});
} }
std::cout << std::endl;
}); });
return result; return result;
} }
......
...@@ -404,12 +404,14 @@ void transpose_test() ...@@ -404,12 +404,14 @@ void transpose_test()
p.add_instruction(rtg::transpose{perm}, l); p.add_instruction(rtg::transpose{perm}, l);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); }); result.visit([&] (auto output){
float tol = 1e-6; std::vector<size_t> new_lens = {1,3,2,2};
for (int i = 0; i < results_vector.size(); i++) { std::vector<size_t> new_strides = {12,1,6,3};
std::cout << results_vector[i] << std::endl; EXPECT(bool{output.get_shape().lens() == new_lens});
} EXPECT(bool{output.get_shape().strides() == new_strides});
});
} }
int main() int main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment