Commit ed6bf6d1 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added transpose and contiguous operators ... need tests

parent e13947ec
......@@ -235,38 +235,44 @@ struct transpose
{
check_shapes{inputs}.has(1);
auto input = inputs.at(0);
auto input_size = input.lens();
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if (dims.size() != input_size.size()) {
if (dims.size() != input_lens.size()) {
RTG_THROW("Permutation has wrong number of axes");
}
// DEBUG
for (int i = 0; i < dims.size(); i++) {
std::cout << dims[i] << std::endl;
}
std::cout << std::endl;
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if (!std::is_permutation(axes.begin(), axes.end(), dims.begin())) {
RTG_THROW("Invalid permutation");
}
std::vector<size_t> output_size(input_size.size());
std::vector<size_t> output_strides(input_size.size());
for (int i = 0; i < output_size.size(); i++) {
output_size[i] = input_size[dims[i]];
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
for (int i = 0; i < output_lens.size(); i++) {
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
//std::cout << input_size[i] << " " << output_size[i] << std::endl;
//std::cout << input_strides[i] << " " << output_strides[i] << std::endl;
}
std::cout << std::endl;
return {t, output_size, output_strides};
return {t, output_lens, output_strides};
}
// No evaluation is provided at this level: both parameters are left unnamed
// and ignored, and every call reports "not computable" through RTG_THROW.
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct contiguous
{
std::vector<int64_t> dims;
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if (lens.size() < 2) {
RTG_THROW("Number of dimensions should exceed 1");
}
return {t, lens};
}
};
struct reshape
{
std::vector<int64_t> dims;
......
......@@ -61,30 +61,81 @@ struct cpu_transpose
// Name under which the CPU transpose implementation is reported.
std::string name() const { return "cpu::transpose"; }
// Shape inference is forwarded unchanged to the `op` member's compute_shape.
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
// Builds the result from `output_shape` and the first argument's `data`,
// which is moved rather than copied; no per-element work happens here.
// Only the first entry of `args` is used.
argument compute(shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_contiguous
{
contiguous op;
// Name under which the CPU contiguous implementation is reported.
std::string name() const { return "cpu::contiguous"; }
// Shape inference is forwarded unchanged to the wrapped `contiguous` operator.
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(inputs);
}
// Copies every element of the first argument into a newly constructed
// result of shape `output_shape`.
//
// Elements are read through the input view's multi-index accessor and
// written to the output buffer sequentially, visiting indices in the order
// produced by `dfor` over the input's lengths.
//
// output_shape: shape of the result to construct.
// args:         only args[0] is read.
// Returns the filled result.
// Reports an error through RTG_THROW when the rank is outside 2..6, since
// the index loop is selected per rank and no other ranks are implemented.
argument compute(shape output_shape, std::vector<argument> args) const
{
    // Reject unsupported ranks up front; otherwise the dispatch below would
    // match no case and hand back an output buffer that was never written.
    const auto ndim = output_shape.lens().size();
    if(ndim < 2 || ndim > 6)
    {
        RTG_THROW("cpu::contiguous: only 2 to 6 dimensions are supported");
    }

    argument result{output_shape};
    visit_all(result, args[0])([&](auto output, auto input) {
        auto input_shape = args[0].get_shape();
        const auto& lens = input_shape.lens();
        using value_type = typename decltype(input)::value_type;
        value_type* ptr  = static_cast<value_type*>(output.data());

        // One visitor for every rank: forwards whatever indices dfor
        // supplies to the input accessor and appends to the output.
        auto copy_element = [&](auto... idx) { *ptr++ = input(idx...); };

        // dfor takes one extent per axis, so the arity has to be spelled
        // out per rank.
        switch(ndim)
        {
        case 2: dfor(lens[0], lens[1])(copy_element); break;
        case 3: dfor(lens[0], lens[1], lens[2])(copy_element); break;
        case 4: dfor(lens[0], lens[1], lens[2], lens[3])(copy_element); break;
        case 5: dfor(lens[0], lens[1], lens[2], lens[3], lens[4])(copy_element); break;
        case 6:
            dfor(lens[0], lens[1], lens[2], lens[3], lens[4], lens[5])(copy_element);
            break;
        default: break; // excluded by the rank check above
        }
    });
    return result;
}
......
......@@ -404,12 +404,14 @@ void transpose_test()
p.add_instruction(rtg::transpose{perm}, l);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(12);
result.visit([&] (auto output){ results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6;
for (int i = 0; i < results_vector.size(); i++) {
std::cout << results_vector[i] << std::endl;
}
result.visit([&] (auto output){
std::vector<size_t> new_lens = {1,3,2,2};
std::vector<size_t> new_strides = {12,1,6,3};
EXPECT(bool{output.get_shape().lens() == new_lens});
EXPECT(bool{output.get_shape().strides() == new_strides});
});
}
int main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment