Commit 867e588b authored by Paul's avatar Paul
Browse files

Merge branch 'jit-concat' into jit-concat-pointwise

parents a33f9cca 2483cfa2
...@@ -71,13 +71,13 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -71,13 +71,13 @@ struct concat_compiler : compiler<concat_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
// TODO: Use reduce_dims
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256)); v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment