Commit 41c67dac authored by Khalique's avatar Khalique
Browse files

Merge branch 'one_hot_op' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bert_ops

parents 4c7f3cba b9e0e0fc
...@@ -599,7 +599,7 @@ struct tf_parser ...@@ -599,7 +599,7 @@ struct tf_parser
{ {
auto indices = args[0]->eval().get<int32_t>().to_vector(); auto indices = args[0]->eval().get<int32_t>().to_vector();
int depth = args[1]->eval().at<int32_t>(); int depth = args[1]->eval().at<int32_t>();
int axis = -1; int64_t axis = -1;
size_t num_indices = indices.size(); size_t num_indices = indices.size();
float on_value = args[2]->eval().at<float>(); float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>(); float off_value = args[3]->eval().at<float>();
...@@ -612,8 +612,8 @@ struct tf_parser ...@@ -612,8 +612,8 @@ struct tf_parser
std::fill(output.begin(), output.end(), off_value); std::fill(output.begin(), output.end(), off_value);
for(size_t i = 0; i < num_indices; i++) for(size_t i = 0; i < num_indices; i++)
{ {
if(indices[i] >= 0 and indices[i] < num_indices) if(indices[i] >= 0 and indices[i] < depth)
output[depth * i + indices[i]] = on_value; output.at(depth * i + indices[i]) = on_value;
} }
return prog.add_literal(s, output); return prog.add_literal(s, output);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment