Commit 33f53196 authored by Paul
Browse files

Fix broadcast test

parent 9e284b7c
...@@ -427,33 +427,23 @@ struct broadcast ...@@ -427,33 +427,23 @@ struct broadcast
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto shape0 = inputs.at(0); auto result = inputs.at(0);
auto shape1 = inputs.at(1); auto input = inputs.at(1);
auto shape0_lens = shape0.lens();
auto shape1_lens = shape1.lens(); std::vector<size_t> bcast_strides(result.lens().size(), 0);
auto shape1_strides = shape1.strides(); if(std::all_of(result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
if(std::all_of(shape0_lens.cbegin(), shape1_lens.cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0"); RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
const std::vector<size_t>& bcast_shape_lens = shape0_lens; return {t, result.lens(), std::move(bcast_strides)};
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
return {t, bcast_shape_lens, bcast_shape_strides};
} }
else else
{ {
for(size_t i = 0; i < shape1_lens.size(); i++) assert(result.lens().size()-axis >= input.lens().size());
{ if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin()+axis))
if(shape0_lens[i + axis] != shape1_lens[i]) RTG_THROW("when broadcasting success sizes must match");
RTG_THROW("when broadcasting success sizes must match"); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin()+axis);
} return {t, result.lens(), std::move(bcast_strides)};
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for(size_t i = 0; i < shape1_strides.size(); i++)
{
bcast_shape_strides[i + axis] = shape1_strides[i];
}
return {t, bcast_shape_lens, bcast_shape_strides};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -46,6 +46,8 @@ function(add_test_command NAME EXE) ...@@ -46,6 +46,8 @@ function(add_test_command NAME EXE)
file(MAKE_DIRECTORY ${TEST_DIR}) file(MAKE_DIRECTORY ${TEST_DIR})
file(GENERATE OUTPUT "${TEST_DIR}/run.cmake" file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
CONTENT " CONTENT "
# Remove previous core dump
file(REMOVE ${TEST_DIR}/core)
execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} WORKING_DIRECTORY ${TEST_DIR} RESULT_VARIABLE RESULT) execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} WORKING_DIRECTORY ${TEST_DIR} RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL 0) if(NOT RESULT EQUAL 0)
# TODO: check for core files based on pid when setting /proc/sys/kernel/core_uses_pid # TODO: check for core files based on pid when setting /proc/sys/kernel/core_uses_pid
......
...@@ -90,7 +90,11 @@ void broadcast_test() ...@@ -90,7 +90,11 @@ void broadcast_test()
p.add_instruction(rtg::broadcast{axis}, l1, l2); p.add_instruction(rtg::broadcast{axis}, l1, l2);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int32_t> results_vector(4); auto output = result.get<int32_t>();
EXPECT(output(0,0) == -2);
EXPECT(output(0,1) == -2);
EXPECT(output(1,0) == -3);
EXPECT(output(1,1) == -3);
} }
void add_broadcast_test() void add_broadcast_test()
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment