Commit c8a91e20 authored by Khalique

initial tf progress

parent 547fd938
@@ -35,6 +35,7 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
set(PACKAGE_DEPENDS)
add_subdirectory(onnx)
add_subdirectory(tf)
add_subdirectory(targets/cpu)
if(MIGRAPHX_ENABLE_GPU)
list(APPEND PACKAGE_DEPENDS MIOpen rocblas)
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TF_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_TF_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
/// Create a program from a tensorflow protobuf file
program parse_tf(const std::string& name, bool is_nhwc);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
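The unknown placeholder above lets parsing continue past ops the importer does not handle yet: it keeps the original TF op name for diagnostics and simply forwards its first input shape. A minimal sketch of that behavior, with a made-up op name and dimensions for illustration:

#include <cassert>
#include <migraphx/tf.hpp>

void unknown_example() // hypothetical, for illustration only
{
    migraphx::unknown u{"SomeUnsupportedOp"};
    migraphx::shape s{migraphx::shape::float_type, {1, 3, 224, 224}};
    assert(u.name() == "unknown:SomeUnsupportedOp");
    assert(u.compute_shape({s}) == s); // input shape passes through unchanged
}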
find_package(Protobuf REQUIRED)
protobuf_generate_cpp(
PROTO_SRCS PROTO_HDRS
graph.proto
node_def.proto
attr_value.proto
tensor.proto
tensor_shape.proto
resource_handle.proto
types.proto
function.proto
op_def.proto
versions.proto
)
add_library(tf-proto STATIC ${PROTO_SRCS})
target_include_directories(tf-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(tf-proto PRIVATE -w)
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(migraphx_tf tf.cpp)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
rocm_clang_tidy_check(migraphx_tf)
target_link_libraries(migraphx_tf PRIVATE tf-proto)
target_link_libraries(migraphx_tf PUBLIC migraphx)
rocm_install_targets(
TARGETS migraphx_tf
)
add_executable(read_tf read_tf.cpp)
rocm_clang_tidy_check(read_tf)
target_link_libraries(read_tf migraphx_tf)
# add_executable(verify_onnx verify_onnx.cpp)
# rocm_clang_tidy_check(verify_onnx)
# target_link_libraries(verify_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
# add_executable(perf_onnx perf_onnx.cpp)
# rocm_clang_tidy_check(perf_onnx)
# target_link_libraries(perf_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
# endif()
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";
// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
// LINT.IfChange
message ListValue {
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated DataType type = 6 [packed = true]; // "list(type)"
repeated TensorShapeProto shape = 7; // "list(shape)"
repeated TensorProto tensor = 8; // "list(tensor)"
repeated NameAttrList func = 9; // "list(attr)"
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)
oneof value {
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
DataType type = 6; // "type"
TensorShapeProto shape = 7; // "shape"
TensorProto tensor = 8; // "tensor"
ListValue list = 1; // any "list(...)"
// "func" represents a function. func.name is a function's name or
// a primitive op's name. func.attr.first is the name of an attr
// defined for that function. func.attr.second is the value for
// that attr in the instantiation.
NameAttrList func = 10;
// This is a placeholder only used in nodes defined inside a
// function. It indicates the attr value will be supplied when
// the function is instantiated. For example, let us suppose a
// node "N" in function "FN". "N" has an attr "A" with value
// placeholder = "foo". When FN is instantiated with attr "foo"
// set to "bar", the instantiated node N's attr A will have been
// given the value "bar".
string placeholder = 9;
}
}
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
string name = 1;
map<string, AttrValue> attr = 2;
}
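With the C++ code protoc generates from this file, the oneof above is inspected through value_case() and the per-field accessors. A hedged sketch of reading an attr (read_attr is an invented name, not part of this commit):

#include <attr_value.pb.h>

void read_attr(const tensorflow::AttrValue& v) // hypothetical helper
{
    if(v.value_case() == tensorflow::AttrValue::kI)
    {
        auto n = v.i(); // "int" attr
        (void)n;
    }
    else if(v.value_case() == tensorflow::AttrValue::kList)
    {
        for(auto x : v.list().i()) // "list(int)" attr
            (void)x;
    }
}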
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";
// A library is a set of named functions.
message FunctionDefLibrary {
repeated FunctionDef function = 1;
repeated GradientDef gradient = 2;
}
// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
//
// TODO(zhifengc):
// * device spec, etc.
message FunctionDef {
// The definition of the function's name, arguments, return values,
// attrs etc.
OpDef signature = 1;
// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;
// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;
// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
// `node_def`) or as a return value of the function (in `ret`).
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
// list in some cases (instead of just single outputs). Also, we
// need to be able to deal with lists of unknown length (so the
// output index may not be known at function definition time). So
// we use the following format instead:
// * "fun_in" where "fun_in" is the name of a function input arg in
// the `signature` field above. This represents that input, whether
// it is a single tensor or a list.
// * "fun_in:0" gives the first element of a function input arg (a
// non-list input is considered a list of length 1 for these
// purposes).
// * "node:out" where "node" is the name of a node in `node_def` and
// "out" is the name one of its op's output arguments (the name
// comes from the OpDef of the node's op). This represents that
// node's output, whether it is a single tensor or a list.
// Note: We enforce that an op's output arguments are never
// renamed in the backwards-compatibility test.
// * "node:out:0" gives the first element of a node output arg (a
// non-list output is considered a list of length 1 for these
// purposes).
//
// NOT CURRENTLY SUPPORTED (but may be in the future):
// * "node:out:-1" gives last element in a node output list
// * "node:out:1:" gives a list with all but the first element in a
// node output list
// * "node:out::-1" gives a list with all but the last element in a
// node output list
// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.
// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;
// A mapping from the output arg names from `signature` to the
// outputs from `node_def` that should be returned by the function.
map<string, string> ret = 4;
}
// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g' is a function
// that takes N + M inputs and produces N outputs.
//
// I.e. if we have
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
// dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
string function_name = 1; // The function name.
string gradient_func = 2; // The gradient function's name.
}
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "node_def.proto";
import "function.proto";
import "versions.proto";
// Represents the graph of operations
message GraphDef {
repeated NodeDef node = 1;
// Compatibility versions of the graph. See core/public/version.h for version
// history. The GraphDef version is distinct from the TensorFlow version, and
// each release of TensorFlow will support a range of GraphDef versions.
VersionDef versions = 4;
// Deprecated single version field; use versions above instead. Since all
// GraphDef changes before "versions" was introduced were forward
// compatible, this field is entirely ignored.
int32 version = 3 [deprecated = true];
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
//
// Naming:
// * library.function.name are in a flat namespace.
// NOTE: We may need to change it to be hierarchical to support
// different orgs. E.g.,
// { "/google/nn", { ... }},
// { "/google/vision", { ... }}
// { "/org_foo/module_bar", { ... }}
// map<string, FunctionDefLib> named_lib;
// * If node[i].op is the name of one function in "library",
// node[i] is deemed as a function call. Otherwise, node[i].op
// must be a primitive operation supported by the runtime.
//
//
// Function call semantics:
//
// * The callee may start execution as soon as some of its inputs
// are ready. The caller may want to use the Tuple() mechanism to
// ensure all inputs are ready at the same time.
//
// * The consumer of return values may start executing as soon as
// the return values the consumer depends on are ready. The
// consumer may want to use the Tuple() mechanism to ensure the
// consumer does not start until all return values of the callee
// function are ready.
FunctionDefLibrary library = 2;
};
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "attr_value.proto";
message NodeDef {
// The name given to this operator. Used for naming inputs,
// logging, visualization, etc. Unique within a single GraphDef.
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
string name = 1;
// The operation name. There may be custom parameters in attrs.
// Op names starting with an underscore are reserved for internal use.
string op = 2;
// Each input is "node:src_output" with "node" being a string name and
// "src_output" indicating which output tensor to use from "node". If
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
// may optionally be followed by control inputs that have the format
// "^node".
repeated string input = 3;
// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;
// Operation-specific graph-construction-time configuration.
// Note that this should include all attrs defined in the
// corresponding OpDef, including those with a value matching
// the default -- this allows the default to change and makes
// NodeDefs easier to interpret on their own. However, if
// an attr with a default is not specified in this list, the
// default will be used.
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
// one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef
// attr's type field.
// TODO(josh11b): Add some examples here showing best practices.
map<string, AttrValue> attr = 5;
};
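The "node:src_output" and "^node" input formats described above have to be decoded by any GraphDef consumer. A hypothetical helper (not part of this commit; the parser below does not yet strip these decorations) sketching that decoding:

#include <string>
#include <utility>

std::pair<std::string, int> split_input(const std::string& input)
{
    std::string node = input;
    if(!node.empty() && node.front() == '^')
        node.erase(0, 1); // control input "^node"
    int src_output = 0;
    auto pos = node.find(':');
    if(pos != std::string::npos)
    {
        src_output = std::stoi(node.substr(pos + 1)); // ":0" suffix may be omitted
        node.erase(pos);
    }
    return {node, src_output};
}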
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "attr_value.proto";
import "types.proto";
// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;
// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;
// Human readable description.
string description = 2;
// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;
// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};
// Description of the input(s).
repeated ArgDef input_arg = 2;
// Description of the output(s).
repeated ArgDef output_arg = 3;
// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;
// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;
// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;
// Human-readable description.
string description = 4;
// TODO(josh11b): bool is_optional?
// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.
// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;
// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;
// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;
// One-line human-readable description of what the Op does.
string summary = 5;
// Additional, longer human-readable description of what the Op does.
string description = 6;
// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.
// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;
// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
// TODO(josh11b): Implement that optimization.
bool is_aggregate = 16; // for things like add
// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.
// -------------------------------------------------------------------------
// Optimization constraints.
// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue
// -------------------------------------------------------------------------
// Non-standard options.
// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)
// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;
// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};
// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};
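As a hedged sketch of how the constraints above could be enforced, here is a hypothetical check of the minimum constraint for an "int" attr (list-length checks omitted). Note that has_minimum is an ordinary bool field, so its getter reads the field value rather than testing proto2-style presence:

#include <attr_value.pb.h>
#include <op_def.pb.h>

bool satisfies_minimum(const tensorflow::OpDef::AttrDef& def,
                       const tensorflow::AttrValue& val)
{
    if(!def.has_minimum())
        return true; // no constraint specified
    if(def.type() == "int")
        return val.i() >= def.minimum();
    return true;
}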
#include <iostream>
#include <migraphx/tf.hpp>
int main(int argc, char const* argv[])
{
if(argc > 1)
{
bool is_nhwc = true;
if(argc > 2)
{
// compare string contents; comparing the char* directly only checks pointers
if(std::string(argv[2]) == "nchw")
is_nhwc = false;
}
std::string file = argv[1];
auto prog = migraphx::parse_tf(file, is_nhwc);
std::cout << prog << std::endl;
}
}
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;
// Container in which this resource is placed.
string container = 2;
// Unique name of this resource.
string name = 3;
// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;
// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";
// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
TensorShapeProto tensor_shape = 2;
// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.
// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;
// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;
// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.
// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];
// DT_FLOAT.
repeated float float_val = 5 [packed = true];
// DT_DOUBLE.
repeated double double_val = 6 [packed = true];
// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];
// DT_STRING
repeated bytes string_val = 8;
// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];
// DT_INT64
repeated int64 int64_val = 10 [packed = true];
// DT_BOOL
repeated bool bool_val = 11 [packed = true];
// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];
// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;
// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;
// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];
// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};
// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}
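A hedged sketch of decoding a DT_FLOAT tensor under the rules above, assuming the caller already knows the element count n from the shape (decode_floats is a hypothetical name; only the cases used by the parser below are handled):

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>
#include <tensor.pb.h>

std::vector<float> decode_floats(const tensorflow::TensorProto& t, std::size_t n)
{
    std::vector<float> out(n);
    if(!t.tensor_content().empty())
        std::memcpy(out.data(), t.tensor_content().data(), n * sizeof(float)); // raw row-major bytes
    else if(t.float_val_size() == 1)
        std::fill(out.begin(), out.end(), t.float_val(0)); // version 0: a single value fills the shape
    else
        std::copy(t.float_val().begin(), t.float_val().end(), out.begin());
    return out;
}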
// Protocol buffer representing the shape of tensors.
syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
package tensorflow;
// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;
// Optional name of the tensor dimension.
string name = 2;
};
// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;
// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};
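Since a dim size of -1 or unknown_rank marks an unknown shape, a consumer that requires static shapes should reject those cases. A hedged sketch (to_dims is a hypothetical name; the parse_dims in the parser below currently assumes fully known dimensions):

#include <stdexcept>
#include <vector>
#include <tensor_shape.pb.h>

std::vector<std::size_t> to_dims(const tensorflow::TensorShapeProto& s)
{
    if(s.unknown_rank())
        throw std::runtime_error("tensor rank is unknown");
    std::vector<std::size_t> dims;
    for(const auto& d : s.dim())
    {
        if(d.size() < 0) // -1 means "unknown"
            throw std::runtime_error("tensor dimension is unknown");
        dims.push_back(static_cast<std::size_t>(d.size()));
    }
    return dims;
}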
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <algorithm> // std::copy, std::transform
#include <cassert>
#include <utility>
#include <vector>
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::unordered_map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
bool is_nhwc = true;
std::unordered_map<std::string, op_func> ops;
tf_parser()
{
add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_binary_op("BiasAdd", op::add{});
// add_mem_op("AvgPool", &tf_parser::parse_pooling);
// add_mem_op("ConcatV2", &tf_parser::parse_concat);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
}
template <class F>
void add_mem_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
return add_broadcastable_binary_op(args[0], args[1], x);
});
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape() != arg1->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
std::vector<std::size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, {arg0, arg1});
}
}
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-4f;
float momentum = 1.f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(attributes, "epsilon"))
{
epsilon = attributes.at("epsilon").f();
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
// get index for axis within args
std::size_t axis_idx = attributes.at("N").i();
std::size_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_tensor(attributes.at("value").tensor());
return prog.add_literal(v);
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::convolution op;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::convolution::same;
}
else if (pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<std::size_t> padding(4);
copy(attributes.at("explicit_paddings").list().i(), padding.begin());
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
{
copy(attributes.at("strides").list().i(), op.stride.begin());
}
if(contains(attributes, "dilations"))
{
copy(attributes.at("dilations").list().i(), op.dilation.begin());
}
auto l0 = args[1];
if(is_nhwc)
{
// tf Conv2D weights are laid out HWIO; migraphx expects OIHW
l0 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, l0);
}
return prog.add_instruction(op, {args[0], l0});
}
// instruction_ref parse_pooling(const std::string& name,
// attribute_map attributes,
// std::vector<instruction_ref> args)
// {
// op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// if(contains(attributes, "pads"))
// {
// std::vector<std::size_t> padding(4);
// copy(attributes["pads"].ints(), padding.begin());
// if(padding.size() != 4)
// {
// MIGRAPHX_THROW("padding should have 4 values");
// }
// if(padding[0] != padding[2] || padding[1] != padding[3])
// {
// MIGRAPHX_THROW("migraphx does not support asymetric padding");
// }
// op.padding[0] = padding[0];
// op.padding[1] = padding[1];
// }
// if(contains(attributes, "strides"))
// {
// copy(attributes["strides"].ints(), op.stride.begin());
// }
// if(contains(attributes, "kernel_shape"))
// {
// copy(attributes["kernel_shape"].ints(), op.lengths.begin());
// }
// if(contains(attributes, "auto_pad"))
// {
// auto s = attributes["auto_pad"].s();
// if(to_upper(s) != "NOTSET")
// {
// MIGRAPHX_THROW("auto_pad is not supported for pooling");
// }
// }
// return prog.add_instruction(op, std::move(args));
// }
void parse_from(std::istream& is)
{
tensorflow::GraphDef graph;
if(graph.ParseFromIstream(&is))
{
this->parse_graph(graph);
}
else
{
throw std::runtime_error("Failed reading");
}
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
for(auto&& input : input_nodes)
{
const std::string& name = input.name();
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
if(is_nhwc)
{
// nhwc to nchw; keep the transposed result so downstream nodes use it
instructions[name] =
prog.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, instructions[name]);
}
}
for(auto&& p : nodes)
{
this->parse_node(get_name(p.second));
}
}
void parse_node(const std::string& name)
{
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<instruction_ref> args;
std::cout << name << std::endl;
for(auto&& input : node.input())
{
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
}
else
{
args.push_back(instructions.at(input));
}
}
if(ops.count(node.op()) == 0)
{
instructions[name] = prog.add_instruction(unknown{node.op()}, args);
}
else
{
instructions[name] = ops[node.op()](get_attributes(node), args);
}
}
}
static attribute_map get_attributes(const tensorflow::NodeDef& node)
{
attribute_map result;
for(auto&& attr : node.attr())
{
result[attr.first] = attr.second;
}
return result;
}
static std::string get_name(const tensorflow::NodeDef& node)
{
return node.name();
}
static node_map get_nodes(const tensorflow::GraphDef& graph, std::vector<tensorflow::NodeDef>& input_nodes)
{
node_map result;
for(auto&& node : graph.node())
{
auto node_name = get_name(node);
// assume each node in graph has an associated name
if(node_name.empty())
MIGRAPHX_THROW("tf node with no name found");
result[node_name] = node;
if(node.op() == "Placeholder")
{
input_nodes.push_back(node);
}
}
return result;
}
static shape::type_t parse_type(const tensorflow::DataType t)
{
shape::type_t shape_type{};
switch(t)
{
case tensorflow::DataType::DT_INVALID:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case tensorflow::DataType::DT_COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case tensorflow::DataType::DT_QINT8:
break; // throw std::runtime_error("Unsupported type QINT8");
case tensorflow::DataType::DT_QUINT8:
break; // throw std::runtime_error("Unsupported type QUINT8");
case tensorflow::DataType::DT_QINT32:
break; // throw std::runtime_error("Unsupported type QINT32");
case tensorflow::DataType::DT_BFLOAT16:
break; // throw std::runtime_error("Unsupported type BFLOAT16");
case tensorflow::DataType::DT_QINT16:
break; // throw std::runtime_error("Unsupported type QINT16");
case tensorflow::DataType::DT_QUINT16:
break; // throw std::runtime_error("Unsupported type QUINT16");
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_RESOURCE:
break; // throw std::runtime_error("Unsupported type RESOURCE");
case tensorflow::DataType::DT_VARIANT:
break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
default:
break;
}
return shape_type;
}
static literal parse_tensor(const tensorflow::TensorProto& t)
{
std::vector<size_t> dims = parse_dims(t.tensor_shape());
if(dims.empty())
{
dims = {1};
}
if(!t.tensor_content().empty()) // has raw data
{
const std::string& s = t.tensor_content();
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
// raw bytes are packed at the dtype's native width, so the literal's
// type must match the tensorflow dtype exactly
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16: return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64: return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default:
break;
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.dtype())
{
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, t.float_val().begin(), t.float_val().end()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
case tensorflow::DataType::DT_UINT16:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
case tensorflow::DataType::DT_INT16:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, t.int_val().begin(), t.int_val().end()};
case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, t.int64_val().begin(), t.int64_val().end()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL:
return literal{{shape::int32_type, dims}, t.bool_val().begin(), t.bool_val().end()};
case tensorflow::DataType::DT_HALF:
return literal{{shape::half_type, dims}, t.half_val().begin(), t.half_val().end()};
case tensorflow::DataType::DT_DOUBLE:
return literal{
{shape::double_type, dims}, t.double_val().begin(), t.double_val().end()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error("");
case tensorflow::DataType::DT_UINT64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error("");
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error("");
default:
break;
}
MIGRAPHX_THROW("Invalid tensor type");
}
static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
{
std::vector<size_t> dims;
auto input_dims = s.dim();
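// note: per tensor_shape.proto a dim size of -1 means "unknown"; this
// initial parser assumes dimensions are fully known (see the defensive
// sketch after tensor_shape.proto above)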
for(auto dim : input_dims)
{
dims.push_back(dim.size());
}
return dims;
}
};
program parse_tf(const std::string& name, bool is_nhwc)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf_parser parser;
parser.is_nhwc = is_nhwc;
#ifndef NDEBUG
// Log the program when it can't be parsed
try
{
parser.parse_from(input);
}
catch(...)
{
std::cerr << parser.prog << std::endl;
throw;
}
#else
parser.parse_from(input);
#endif
return std::move(parser.prog);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;
// Data types that all computation devices are expected to be
// capable of supporting.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;
// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;
// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;
// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};
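A hedged sketch of the consumer check quoted above (can_consume and the consumer argument are illustrative; the producer >= min_producer half of the rule lives on the consumer side and is omitted here):

#include <versions.pb.h>

bool can_consume(const tensorflow::VersionDef& v, int consumer)
{
    if(consumer < v.min_consumer())
        return false; // consumer >= min_consumer violated
    for(int bad : v.bad_consumers())
        if(consumer == bad)
            return false; // consumer is explicitly disallowed
    return true;
}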