Commit 33f53196 authored by Paul's avatar Paul
Browse files

Fix broadcast test

parent 9e284b7c
......@@ -427,33 +427,23 @@ struct broadcast
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto shape0 = inputs.at(0);
auto shape1 = inputs.at(1);
auto shape0_lens = shape0.lens();
auto shape1_lens = shape1.lens();
auto shape1_strides = shape1.strides();
if(std::all_of(shape0_lens.cbegin(), shape1_lens.cend(), [&](auto x) { return x == 1; }))
auto result = inputs.at(0);
auto input = inputs.at(1);
std::vector<size_t> bcast_strides(result.lens().size(), 0);
if(std::all_of(result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{
if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
const std::vector<size_t>& bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
return {t, bcast_shape_lens, bcast_shape_strides};
return {t, result.lens(), std::move(bcast_strides)};
}
else
{
for(size_t i = 0; i < shape1_lens.size(); i++)
{
if(shape0_lens[i + axis] != shape1_lens[i])
RTG_THROW("when broadcasting success sizes must match");
}
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for(size_t i = 0; i < shape1_strides.size(); i++)
{
bcast_shape_strides[i + axis] = shape1_strides[i];
}
return {t, bcast_shape_lens, bcast_shape_strides};
assert(result.lens().size()-axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin()+axis))
RTG_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin()+axis);
return {t, result.lens(), std::move(bcast_strides)};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -46,6 +46,8 @@ function(add_test_command NAME EXE)
file(MAKE_DIRECTORY ${TEST_DIR})
file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
CONTENT "
# Remove previous core dump
file(REMOVE ${TEST_DIR}/core)
execute_process(COMMAND $<TARGET_FILE:${EXE}> ${ARGN} WORKING_DIRECTORY ${TEST_DIR} RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL 0)
# TODO: check for core files based on pid when setting /proc/sys/kernel/core_uses_pid
......
......@@ -90,7 +90,11 @@ void broadcast_test()
p.add_instruction(rtg::broadcast{axis}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<int32_t> results_vector(4);
auto output = result.get<int32_t>();
EXPECT(output(0,0) == -2);
EXPECT(output(0,1) == -2);
EXPECT(output(1,0) == -3);
EXPECT(output(1,1) == -3);
}
void add_broadcast_test()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment