"vscode:/vscode.git/clone" did not exist on "78ed423d73e2982b81b61154cc18fa36002b2b1e"
Commit 796d1802 authored by Scott Thornton

Added contiguous test and fixed up operator

parent ed6bf6d1
@@ -259,7 +259,6 @@ struct transpose
 struct contiguous
 {
-    std::vector<int64_t> dims;
     std::string name() const { return "contiguous"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
@@ -271,6 +270,7 @@ struct contiguous
         }
         return {t, lens};
     }
+    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
 };
 struct reshape
...
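The hunk above drops the unused dims member from the reference contiguous operator and gives it a compute that simply throws, leaving the real work to a target implementation. compute_shape returns {t, lens}, a shape built from the element type and lengths alone, which presumably carries the default packed (row-major) strides. As a standalone illustration of that stride rule (the helper name packed_strides is mine, not from this repository):

#include <cstdint>
#include <vector>

// Packed (row-major) strides: stride[i] is the product of all lengths after i,
// so the innermost dimension has stride 1.
std::vector<std::int64_t> packed_strides(const std::vector<std::int64_t>& lens)
{
    std::vector<std::int64_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size(); i > 1; i--)
        strides[i - 2] = strides[i - 1] * lens[i - 1];
    return strides;
}

For the lengths {1, 3, 2, 2} used in the test below this gives {12, 4, 2, 1}, the packed counterpart of the transposed strides {12, 1, 6, 3}.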
@@ -440,6 +440,10 @@ struct cpu_apply
         {
             apply_reshape(it);
         }
+        else if(it->op.name() == "contiguous")
+        {
+            apply_contiguous(it);
+        }
         else if(it->op.name() == "transpose")
         {
             apply_transpose(it);
@@ -505,6 +509,12 @@ struct cpu_apply
         prog->replace_instruction(ins, cpu_reshape{op}, ins->arguments);
     }
+    void apply_contiguous(instruction_ref ins)
+    {
+        auto&& op = any_cast<contiguous>(ins->op);
+        prog->replace_instruction(ins, cpu_contiguous{op}, ins->arguments);
+    }
     void apply_transpose(instruction_ref ins)
     {
         auto&& op = any_cast<transpose>(ins->op);
...
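The lowering above replaces a contiguous instruction with cpu_contiguous; its body is not shown in this diff. Presumably it walks the input view by its (possibly non-packed) strides and writes the elements into a freshly allocated packed buffer. A rough standalone sketch of that copy, using plain pointers and a hypothetical copy_to_packed helper rather than the repository's argument and tensor types:

#include <cstddef>
#include <vector>

// Copy a strided 4-D view into a packed row-major destination buffer.
// lens and strides describe the source view; dst must hold
// lens[0]*lens[1]*lens[2]*lens[3] floats.
void copy_to_packed(const float* src,
                    const std::vector<std::size_t>& lens,
                    const std::vector<std::size_t>& strides,
                    float* dst)
{
    std::size_t out = 0;
    for(std::size_t n = 0; n < lens[0]; n++)
        for(std::size_t c = 0; c < lens[1]; c++)
            for(std::size_t h = 0; h < lens[2]; h++)
                for(std::size_t w = 0; w < lens[3]; w++)
                    dst[out++] = src[n * strides[0] + c * strides[1] +
                                     h * strides[2] + w * strides[3]];
}

Applied to the test's view ({1, 3, 2, 2} lengths, {12, 1, 6, 3} strides over the values 0..11) this produces exactly the gold sequence used below.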
@@ -414,6 +414,25 @@ void transpose_test()
     });
 }
+void contiguous_test()
+{
+    rtg::shape a_shape{rtg::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
+    std::vector<float> data(12);
+    std::iota(data.begin(), data.end(), 0);
+    rtg::program p;
+    auto l = p.add_literal(rtg::literal{a_shape, data});
+    p.add_instruction(rtg::contiguous{}, l);
+    p.compile(rtg::cpu::cpu_target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector(12);
+    result.visit([&](auto output) {
+        std::vector<size_t> new_lens = {1, 3, 2, 2};
+        std::vector<size_t> new_strides = {12, 1, 6, 3};
+        std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
+    });
+}
 int main()
 {
     exp_test();
@@ -423,6 +442,7 @@ int main()
     gemm_test();
     reshape_test();
     transpose_test();
+    contiguous_test();
     softmax_test();
     conv2d_test();
     conv2d_padding_test();
...
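About the gold values in contiguous_test: the literal's shape {1, 3, 2, 2} with strides {12, 1, 6, 3} is a transposed view of the packed values 0..11, so reading it in row-major order yields 0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11. As committed, the visit lambda declares results_vector, new_lens, new_strides, and gold but never compares anything against the output. A sketch of how the check might be completed, assuming the visited output exposes begin()/end() and the test harness provides an EXPECT macro (both assumptions, not confirmed by this diff):

result.visit([&](auto output) {
    // assumed API: copy the evaluated tensor into a flat vector
    results_vector.assign(output.begin(), output.end());
    std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
    EXPECT(results_vector == gold); // assumed test macro
});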