Commit 1e56489a authored by Khalique's avatar Khalique
Browse files

work on onehot implementation

parent 49e65e08
......@@ -173,6 +173,7 @@ struct tf_parser
add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
......@@ -563,6 +564,32 @@ struct tf_parser
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto indices = args[0]->eval().get<int64_t>().to_vector();
int depth = args[1]->eval().at<int32_t>();
int axis = -1;
size_t num_indices = indices.size();
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
if (contains(attributes, "axis"))
axis = attributes.at("axis").i();
if(axis == -1)
{
shape s{shape::float_type, {num_indices, static_cast<size_t>(depth)}};
std::vector<float> output(num_indices * depth);
std::fill(output.begin(), output.end(), off_value);
for (size_t i = 0; i < num_indices; i++)
{
if(indices[i] >= 0 and indices[i] < num_indices)
output[depth*i + indices[i]] = on_value;
}
return prog.add_literal(s, output);
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment