Commit 52df0ca8 authored by Khalique's avatar Khalique
Browse files

Merge branch 'one_hot_op' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bert_ops

parents 41c67dac cb6deeb2
...@@ -597,25 +597,27 @@ struct tf_parser ...@@ -597,25 +597,27 @@ struct tf_parser
instruction_ref instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
auto indices = args[0]->eval().get<int32_t>().to_vector(); // auto indices = args[0]->eval().get<int32_t>().to_vector();
int depth = args[1]->eval().at<int32_t>(); size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
size_t num_indices = indices.size(); int64_t axis = -1;
float on_value = args[2]->eval().at<float>(); // size_t num_indices = indices.size();
float off_value = args[3]->eval().at<float>(); float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
std::vector<float> depth_input(depth * depth, off_value);
for(int i = 0; i < depth; i++)
{
depth_input[depth * i + i] = on_value;
}
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
axis = attributes.at("axis").i(); axis = attributes.at("axis").i();
if(axis == -1) if(axis == -1)
{ {
shape s{shape::float_type, {num_indices, static_cast<size_t>(depth)}}; shape s{shape::float_type, {depth, depth}};
std::vector<float> output(num_indices * depth); auto l0 = prog.add_literal({s, depth_input});
std::fill(output.begin(), output.end(), off_value); return prog.add_instruction(op::gather{0}, {l0, args[0]});
for(size_t i = 0; i < num_indices; i++)
{
if(indices[i] >= 0 and indices[i] < depth)
output.at(depth * i + indices[i]) = on_value;
}
return prog.add_literal(s, output);
} }
MIGRAPHX_THROW("MIGraphX does not support axis != -1"); MIGRAPHX_THROW("MIGraphX does not support axis != -1");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment