myelination.cc 5.81 KB
Newer Older
Terry Koo's avatar
Terry Koo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "dragnn/runtime/myelin/myelination.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/myelin/myelin_cell_converter.h"
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include "dragnn/runtime/trained_model.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Rewrites the registered Component subclass name in the |component_spec| to
// its Myelin-based counterpart.  Only "DynamicComponent" is supported; for any
// other subclass this returns Unimplemented without touching the spec.
tensorflow::Status MyelinateComponentSubclass(ComponentSpec *component_spec) {
  const string builder_name =
      GetNormalizedComponentBuilderName(*component_spec);
  if (builder_name == "DynamicComponent") {
    // The Myelin-based version of "FooComponent" is, by convention, registered
    // as "MyelinFooComponent".
    const string myelin_name =
        tensorflow::strings::StrCat("Myelin", builder_name);
    component_spec->mutable_component_builder()->set_registered_name(
        myelin_name);
    return tensorflow::Status::OK();
  }

  return tensorflow::errors::Unimplemented(
      "No Myelin-based version of Component subclass '", builder_name, "'");
}

// Appends the component specs in the |master_spec| whose names appear in
// |component_names| to |matching_components|, following the iteration order of
// |component_names|.  Returns InvalidArgument if the |master_spec| contains two
// components with the same name, or if any requested name is not present in
// it.  On error, returns non-OK and leaves |matching_components| unmodified.
tensorflow::Status GetMatchingComponentSpecs(
    const std::set<string> &component_names, MasterSpec *master_spec,
    std::vector<ComponentSpec *> *matching_components) {
  // Index the components in the |master_spec| by name.
  std::map<string, ComponentSpec *> components;
  for (ComponentSpec &component_spec : *master_spec->mutable_component()) {
    if (!components.emplace(component_spec.name(), &component_spec).second) {
      return tensorflow::errors::InvalidArgument("Duplicate component name: ",
                                                 component_spec.name());
    }
  }

  // Resolve every requested name into a staging list first, so that an unknown
  // name part-way through cannot leave a partial result in the output.
  std::vector<ComponentSpec *> matches;
  matches.reserve(component_names.size());
  for (const string &component_name : component_names) {
    // Search once and reuse the iterator, instead of find() followed by a
    // second lookup through operator[].
    const auto it = components.find(component_name);
    if (it == components.end()) {
      return tensorflow::errors::InvalidArgument("Unknown component name: ",
                                                 component_name);
    }
    matches.push_back(it->second);
  }

  // All names resolved; commit the staged matches to the output.
  matching_components->insert(matching_components->end(), matches.begin(),
                              matches.end());
  return tensorflow::Status::OK();
}

}  // namespace

// Converts the components named in |component_names| to Myelin-based versions.
// Reads the text-format MasterSpec at |master_spec_path| and the trained model
// at |saved_model_dir|, then writes into |output_dir| an updated "master-spec"
// file plus one "<component name>.flow" file per selected component.  Returns
// non-OK on the first failing step.  The updated MasterSpec is written before
// the trained model is loaded, so it may already exist when a later step fails.
tensorflow::Status MyelinateCells(const string &saved_model_dir,
                                  const string &master_spec_path,
                                  const std::set<string> &component_names,
                                  const string &output_dir) {
  tensorflow::Env *const env = tensorflow::Env::Default();

  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(env, master_spec_path, &master_spec));

  std::vector<ComponentSpec *> selected_specs;
  TF_RETURN_IF_ERROR(GetMatchingComponentSpecs(component_names, &master_spec,
                                               &selected_specs));

  // Maps a component name to the path of its output Flow file.
  const auto flow_path_for = [&output_dir](const string &name) {
    const string file_name = tensorflow::strings::StrCat(name, ".flow");
    return tensorflow::io::JoinPath(output_dir, file_name);
  };

  // First pass: edit the MasterSpec only.  This is done before loading the
  // trained model, which is slow, so that spec problems surface early.
  for (ComponentSpec *spec : selected_specs) {
    // Register the Flow file as a resource of the component.  The file itself
    // is produced in the second pass below.
    const string flow_path = flow_path_for(spec->name());
    TF_RETURN_IF_ERROR(AddMyelinFlowResource(flow_path, spec));

    // Switch the Component subclass to its Myelin-based version.
    TF_RETURN_IF_ERROR(MyelinateComponentSubclass(spec));

    // Every fixed and linked channel gets embedding_dim=-1.
    for (auto &channel : *spec->mutable_fixed_feature()) {
      channel.set_embedding_dim(-1);
    }
    for (auto &channel : *spec->mutable_linked_feature()) {
      channel.set_embedding_dim(-1);
    }
  }

  // Persist the updated MasterSpec, creating |output_dir| if needed.
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(output_dir));
  const string master_spec_output =
      tensorflow::io::JoinPath(output_dir, "master-spec");
  TF_RETURN_IF_ERROR(
      tensorflow::WriteTextProto(env, master_spec_output, master_spec));

  // Second pass: load the trained model, then convert each selected component
  // into a Flow and write it to the path registered in the first pass.
  TrainedModel trained_model;
  TF_RETURN_IF_ERROR(trained_model.Reset(saved_model_dir));
  for (const ComponentSpec *spec : selected_specs) {
    string flow_contents;
    TF_RETURN_IF_ERROR(MyelinCellConverter::Convert(
        spec->name(), trained_model, &flow_contents));

    TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
        env, flow_path_for(spec->name()), flow_contents));
  }

  return tensorflow::Status::OK();
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet