Commit 92324d57 authored by Paul's avatar Paul
Browse files

Format

parent 4d66f031
......@@ -48,8 +48,8 @@ struct softmax_compiler : compiler<softmax_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = v.at("axis").to<int64_t>();
auto faxis = find_fast_axis({inputs.front()});
auto axis = v.at("axis").to<int64_t>();
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(inputs.back().lens()[faxis] == 1)
......@@ -66,7 +66,9 @@ struct softmax_compiler : compiler<softmax_compiler>
options.inputs = inputs;
options.kernel_name = "softmax_kernel";
auto src = interpolate_string(softmax_kernel, {{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
auto src = interpolate_string(
softmax_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment