Commit d92520de authored by Paul's avatar Paul
Browse files

Fix bug in pad operator from dimension reduction

parent b0798343
...@@ -65,10 +65,32 @@ struct pad_compiler : compiler<pad_compiler> ...@@ -65,10 +65,32 @@ struct pad_compiler : compiler<pad_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto padding = v.at("pads").to_vector<int64_t>();
auto input_lens = inputs.front().lens();
std::vector<size_t> offsets(input_lens.size());
std::copy(padding.begin(), padding.begin() + offsets.size(), offsets.begin());
auto offset_lens = input_lens;
std::transform(input_lens.begin(), input_lens.end(), offsets.begin(), offset_lens.begin(), [&](auto input, auto offset) {
return input+offset;
});
auto vinputs = inputs;
vinputs.push_back(inputs.front().with_lens(offset_lens));
auto rinputs = reduce_dims(vinputs);
auto rinput_lens = rinputs.front().lens();
auto roffset_lens = rinputs.back().lens();
std::vector<size_t> roffsets(roffset_lens.size());
std::transform(rinput_lens.begin(), rinput_lens.end(), roffset_lens.begin(), roffsets.begin(), [](auto input, auto offset_dim) {
return offset_dim - input;
});
rinputs.pop_back();
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = rinputs;
options.kernel_name = "pad_kernel"; options.kernel_name = "pad_kernel";
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements())); options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
...@@ -79,14 +101,9 @@ struct pad_compiler : compiler<pad_compiler> ...@@ -79,14 +101,9 @@ struct pad_compiler : compiler<pad_compiler>
if(float_equal(pad_val, std::numeric_limits<float>::max())) if(float_equal(pad_val, std::numeric_limits<float>::max()))
pad_val_string = "highest{}"; pad_val_string = "highest{}";
auto padding = v.at("pads").to_vector<int64_t>();
auto input_lens = inputs.front().lens();
std::vector<size_t> offsets(input_lens.size());
std::copy(padding.begin(), padding.begin() + offsets.size(), offsets.begin());
auto src = interpolate_string( auto src = interpolate_string(
pointwise_kernel, pointwise_kernel,
{{"pad_val", to_string(pad_val_string)}, {"offsets", to_string_range(offsets)}}); {{"pad_val", to_string(pad_val_string)}, {"offsets", to_string_range(roffsets)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment