Commit f24e6384 authored by Ted Themistokleous
Browse files

Update gather jit op

Pair programming with Paul
parent 20d8803c
...@@ -47,8 +47,7 @@ extern "C" { ...@@ -47,8 +47,7 @@ extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output) __global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{ {
make_tensors()(in_data, in_indices, output)([](auto&&... xs) { make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gather_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{AXIS})); gather<${axis}>(xs...);
gather(xs..., settings);
}); });
} }
...@@ -58,7 +57,7 @@ __global__ void gather_kernel(void* in_data, void* in_indices, void* output) ...@@ -58,7 +57,7 @@ __global__ void gather_kernel(void* in_data, void* in_indices, void* output)
)__migraphx__"; )__migraphx__";
struct gathernd_compiler : compiler<gathernd_compiler> struct gather_compiler : compiler<gather_compiler>
{ {
std::vector<std::string> names() const { return {"gather"}; } std::vector<std::string> names() const { return {"gather"}; }
...@@ -72,12 +71,13 @@ struct gathernd_compiler : compiler<gathernd_compiler> ...@@ -72,12 +71,13 @@ struct gathernd_compiler : compiler<gathernd_compiler>
options.kernel_name = "gather_kernel"; options.kernel_name = "gather_kernel";
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
// axis attribute
assert(v.contains("axis")); assert(v.contains("axis"));
auto axis = v.at("axis").to<int64_t>(); auto axis = v.at("axis").to<std::string>();
options.params += " -DAXIS=" + std::to_string(axis);
auto src = interpolate_string(gather_kernel, {{"axis", axis}});
return compile_hip_code_object(gather_kernel, options); // axis attribute
return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
......
...@@ -29,41 +29,19 @@ ...@@ -29,41 +29,19 @@
namespace migraphx { namespace migraphx {
// Gather along a compile-time axis.
//
// Template parameters:
//   axis - dimension of the data tensor along which elements are selected.
//   T, U, V - tensor view types of the data, indices and output arguments.
//
// Parameters:
//   data_t    - tensor the values are read from.
//   indices_t - tensor holding the positions to pick along `axis`; negative
//               entries are shifted by the size of that dimension.
//   output_t  - tensor receiving the gathered values.
//
// For every flat output position `i` handed out by `ind.global_stride`, the
// output multi-index is computed, its `axis` coordinate is replaced by the
// (normalized) entry looked up in `indices_t`, and the element of `data_t` at
// that multi-index is written to `output_t[i]`.
template <int axis, class T, class U, class V>
__device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
{
    auto ind           = make_index();
    auto output_shape  = output_t.get_shape();
    auto axis_dim_size = data_t.get_shape().lens[axis];

    ind.global_stride(output_shape.elements(), [&](auto i) {
        auto idx = output_shape.multi(i);

        // NOTE(review): the lookup uses only the axis coordinate of the output
        // index, which presumably assumes rank-1 indices -- confirm against the
        // shapes the compiler passes in.
        auto in_index = indices_t[idx[axis]];

        // Negative indices count from the end of the gathered dimension.
        in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;

        idx[axis] = in_index;

        // Read from the data tensor: idx now addresses data_t, not indices_t.
        output_t[i] = data_t[idx];
    });
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment