Commit d66941ac authored by Ivan Bogatyy, committed by calberti

Sync w TF r0.12 & Bazel 0.4.3, internal updates (#953)

parent efa4a6cf
......@@ -7,7 +7,7 @@ RUN mkdir -p $SYNTAXNETDIR \
&& apt-get update \
&& apt-get install git zlib1g-dev file swig python2.7 python-dev python-pip python-mock -y \
&& pip install --upgrade pip \
&& pip install -U protobuf==3.0.0 \
&& pip install -U protobuf==3.0.0b2 \
&& pip install asciitree \
&& pip install numpy \
&& wget https://github.com/bazelbuild/bazel/releases/download/0.4.3/bazel-0.4.3-installer-linux-x86_64.sh \
......@@ -15,7 +15,7 @@ RUN mkdir -p $SYNTAXNETDIR \
&& ./bazel-0.4.3-installer-linux-x86_64.sh --user \
&& git clone --recursive https://github.com/tensorflow/models.git \
&& cd $SYNTAXNETDIR/models/syntaxnet/tensorflow \
&& echo "\n\n\n\n" | ./configure \
&& echo -e "\n\n\n\n\n\n" | ./configure \
&& apt-get autoremove -y \
&& apt-get clean
......
......@@ -78,15 +78,27 @@ source. You'll need to install:
* python 2.7:
* python 3 support is not available yet
* pip (python package manager)
* `apt-get install python-pip` on Ubuntu
* `brew` installs pip along with python on OSX
* bazel:
* **version 0.4.3**
* follow the instructions [here](http://bazel.build/docs/install.html)
* Alternatively, download the bazel (0.4.3) `.deb` package for your system
configuration from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases).
* Install it using the command: `sudo dpkg -i <.deb file>`
* Check the installed bazel version by typing: `bazel version`
* swig:
* `apt-get install swig` on Ubuntu
* `brew install swig` on OSX
* protocol buffers, with a version supported by TensorFlow:
* check your protobuf version with `pip freeze | grep protobuf`
* upgrade to a supported version with `pip install -U protobuf==3.0.0b2`
* mock, the testing package:
* `pip install mock`
* asciitree, to draw parse trees on the console for the demo:
* `pip install asciitree`
* numpy, package for scientific computing:
* `pip install numpy`
* mock, package for unit testing:
* `pip install mock`
Once you have completed the above steps, you can build and test SyntaxNet with the
following commands:
......
......@@ -3,9 +3,9 @@ local_repository(
path = "tensorflow",
)
load('@org_tensorflow//tensorflow:workspace.bzl', 'tf_workspace')
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace()
# Specify the minimum required Bazel version.
load("@org_tensorflow//tensorflow:tensorflow.bzl", "check_version")
check_version("0.3.0")
check_version("0.4.3")
......@@ -733,5 +733,5 @@ sh_test(
":parser_trainer",
":testdata",
],
tags = ["notsan"],
tags = ["slow"],
)
......@@ -146,10 +146,10 @@ class AffixTable {
int max_length_;
// Index from affix ids to affix items.
vector<Affix *> affixes_;
std::vector<Affix *> affixes_;
// Buckets for word-to-affix hash map.
vector<Affix *> buckets_;
std::vector<Affix *> buckets_;
TF_DISALLOW_COPY_AND_ASSIGN(AffixTable);
};
......
......@@ -298,6 +298,40 @@ class ArcStandardTransitionSystem : public ParserTransitionSystem {
ParserTransitionState *NewTransitionState(bool training_mode) const override {
return new ArcStandardTransitionState();
}
// Meta information API. Returns token indices to link parser actions back
// to positions in the input sentence.
bool SupportsActionMetaData() const override { return true; }
// Returns the child of a new arc for reduce actions.
int ChildIndex(const ParserState &state,
const ParserAction &action) const override {
switch (ActionType(action)) {
case SHIFT:
return -1;
case LEFT_ARC: // left arc pops stack(1)
return state.Stack(1);
case RIGHT_ARC:
return state.Stack(0);
default:
LOG(FATAL) << "Invalid parser action: " << action;
}
}
// Returns the parent of a new arc for reduce actions.
int ParentIndex(const ParserState &state,
const ParserAction &action) const override {
switch (ActionType(action)) {
case SHIFT:
return -1;
case LEFT_ARC: // left arc pops stack(1)
return state.Stack(0);
case RIGHT_ARC:
return state.Stack(1);
default:
LOG(FATAL) << "Invalid parser action: " << action;
}
}
};
REGISTER_TRANSITION_SYSTEM("arc-standard", ArcStandardTransitionSystem);
......
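The new meta-information API above is easiest to see with concrete stack contents. Below is a minimal Python sketch (not SyntaxNet code) of the conventions `ChildIndex` and `ParentIndex` encode, with a plain list standing in for the parser stack: `stack[-1]` plays the role of `Stack(0)` and `stack[-2]` of `Stack(1)`.

```python
# Illustrative sketch of the arc-standard ChildIndex/ParentIndex conventions.
SHIFT, LEFT_ARC, RIGHT_ARC = range(3)

def child_index(stack, action):
    """Token that becomes the child of the new arc, or -1 for SHIFT."""
    if action == SHIFT:
        return -1
    if action == LEFT_ARC:    # LEFT_ARC pops Stack(1) as the child.
        return stack[-2]
    if action == RIGHT_ARC:   # RIGHT_ARC pops Stack(0) as the child.
        return stack[-1]
    raise ValueError("invalid parser action")

def parent_index(stack, action):
    """Token that becomes the parent of the new arc, or -1 for SHIFT."""
    if action == SHIFT:
        return -1
    if action == LEFT_ARC:    # The parent is the token left on top.
        return stack[-1]
    if action == RIGHT_ARC:
        return stack[-2]
    raise ValueError("invalid parser action")

# With tokens 3 and 7 on the stack (7 on top), LEFT_ARC attaches 3 under 7:
assert child_index([3, 7], LEFT_ARC) == 3
assert parent_index([3, 7], LEFT_ARC) == 7
```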
......@@ -841,7 +841,7 @@ class BeamEvalOutput : public OpKernel {
BatchState *batch_state =
reinterpret_cast<BatchState *>(context->input(0).scalar<int64>()());
const int batch_size = batch_state->BatchSize();
vector<Sentence> documents;
std::vector<Sentence> documents;
for (int beam_id = 0; beam_id < batch_size; ++beam_id) {
if (batch_state->Beam(beam_id).gold_ != nullptr &&
batch_state->Beam(beam_id).AllFinal()) {
......
......@@ -54,7 +54,7 @@ void BinarySegmentState::AddParseToDocument(const ParserState &state,
bool rewrite_root_labels,
Sentence *sentence) const {
if (sentence->token_size() == 0) return;
vector<bool> is_starts(sentence->token_size(), false);
std::vector<bool> is_starts(sentence->token_size(), false);
for (int i = 0; i < NumStarts(state); ++i) {
is_starts[LastStart(i, state)] = true;
}
......
......@@ -95,9 +95,9 @@ class SegmentationTransitionTest : public ::testing::Test {
return result.size() > 0 ? result.value(0) : -1;
}
void CheckStarts(const ParserState &state, const vector<int> &target) {
void CheckStarts(const ParserState &state, const std::vector<int> &target) {
ASSERT_EQ(state.StackSize(), target.size());
vector<int> starts;
std::vector<int> starts;
for (int i = 0; i < state.StackSize(); ++i) {
EXPECT_EQ(state.Stack(i), target[i]);
}
......
......@@ -88,7 +88,7 @@ namespace syntaxnet {
struct CharPropertyImplementation {
unordered_set<char32> chars;
vector<vector<int> > rows;
std::vector<std::vector<int> > rows;
CharPropertyImplementation() {
rows.reserve(10);
rows.resize(1);
......@@ -261,7 +261,7 @@ int CharProperty::NextElementAfter(int c) const {
return *it;
}
REGISTER_CLASS_REGISTRY("char property wrapper", CharPropertyWrapper);
REGISTER_SYNTAXNET_CLASS_REGISTRY("char property wrapper", CharPropertyWrapper);
const CharProperty *CharProperty::Lookup(const char *subclass) {
// Create a CharPropertyWrapper object and delete it. We only care about
......
......@@ -92,7 +92,7 @@ struct CharPropertyWrapper : RegisterableClass<CharPropertyWrapper> {
};
#define REGISTER_CHAR_PROPERTY_WRAPPER(type, component) \
REGISTER_CLASS_COMPONENT(CharPropertyWrapper, type, component)
REGISTER_SYNTAXNET_CLASS_COMPONENT(CharPropertyWrapper, type, component)
#define REGISTER_CHAR_PROPERTY(lsp, name) \
struct name##CharPropertyWrapper : public CharPropertyWrapper { \
......
......@@ -55,7 +55,7 @@ void GetTaskContext(OpKernelConstruction *context, TaskContext *task_context) {
// Outputs the given batch of sentences as a tensor and deletes them.
void OutputDocuments(OpKernelContext *context,
vector<Sentence *> *document_batch) {
std::vector<Sentence *> *document_batch) {
const int64 size = document_batch->size();
Tensor *output;
OP_REQUIRES_OK(context,
......@@ -84,7 +84,7 @@ class DocumentSource : public OpKernel {
void Compute(OpKernelContext *context) override {
mutex_lock lock(mu_);
Sentence *document;
vector<Sentence *> document_batch;
std::vector<Sentence *> document_batch;
while ((document = corpus_->Read()) != nullptr) {
document_batch.push_back(document);
if (static_cast<int>(document_batch.size()) == batch_size_) {
......@@ -166,7 +166,7 @@ class WellFormedFilter : public OpKernel {
void Compute(OpKernelContext *context) override {
auto documents = context->input(0).vec<string>();
vector<Sentence *> output_documents;
std::vector<Sentence *> output_documents;
for (int i = 0; i < documents.size(); ++i) {
Sentence *document = new Sentence;
OP_REQUIRES(context, document->ParseFromString(documents(i)),
......@@ -182,7 +182,7 @@ class WellFormedFilter : public OpKernel {
private:
bool ShouldKeep(const Sentence &doc) {
vector<int> visited(doc.token_size(), -1);
std::vector<int> visited(doc.token_size(), -1);
for (int i = 0; i < doc.token_size(); ++i) {
// Already visited node.
if (visited[i] != -1) continue;
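The visited-array walk that `ShouldKeep` begins here rejects sentences whose head links contain a cycle. A hedged Python sketch of that idea follows; it is an assumption about the elided remainder of the function, not a translation of it.

```python
# Sketch: reject head assignments containing a cycle. heads[i] is the head
# of token i, or -1 for the root.
def heads_are_well_formed(heads):
    visited = [-1] * len(heads)
    for i in range(len(heads)):
        # Walk up from i, tagging the path with i; revisiting a node tagged
        # i means we looped back onto our own path: a cycle.
        node = i
        while node >= 0 and visited[node] != i:
            if visited[node] != -1:   # Path already checked and acyclic.
                break
            visited[node] = i
            node = heads[node]
        else:
            if node >= 0:
                return False          # Cycle detected.
    return True

print(heads_are_well_formed([-1, 0, 1]))  # True: a chain up to the root.
print(heads_are_well_formed([1, 0]))      # False: 0 and 1 head each other.
```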
......@@ -235,7 +235,7 @@ class ProjectivizeFilter : public OpKernel {
void Compute(OpKernelContext *context) override {
auto documents = context->input(0).vec<string>();
vector<Sentence *> output_documents;
std::vector<Sentence *> output_documents;
for (int i = 0; i < documents.size(); ++i) {
Sentence *document = new Sentence;
OP_REQUIRES(context, document->ParseFromString(documents(i)),
......@@ -255,8 +255,8 @@ class ProjectivizeFilter : public OpKernel {
// Left and right boundaries for arcs. The left and right ends of an arc are
// bounded by the arcs that pass over it. If an arc exceeds these bounds it
// will cross an arc passing over it, making it a non-projective arc.
vector<int> left(num_tokens);
vector<int> right(num_tokens);
std::vector<int> left(num_tokens);
std::vector<int> right(num_tokens);
// Lift the shortest non-projective arc until the document is projective.
while (true) {
......
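The comment in this hunk describes projectivization as repeatedly lifting the shortest non-projective arc. The Python sketch below illustrates the lifting idea only; it checks projectivity directly with a descendant test rather than maintaining the `left`/`right` bound arrays the C++ uses, and it is not the elided function body.

```python
# Sketch: projectivize a dependency tree by repeatedly "lifting" the
# shortest non-projective arc, i.e. re-attaching its child to the
# grandparent. heads[i] is the head of token i, or -1 for the root.
def is_projective(heads, child):
    """An arc is projective if every token strictly between head and child
    is a descendant of the head (reachable by following heads upward)."""
    head = heads[child]
    if head < 0:
        return True
    lo, hi = sorted((head, child))
    for k in range(lo + 1, hi):
        node = k
        while node >= 0 and node != head:
            node = heads[node]
        if node != head:
            return False
    return True

def projectivize(heads):
    heads = list(heads)
    while True:
        bad = [c for c in range(len(heads))
               if heads[c] >= 0 and not is_projective(heads, c)]
        if not bad:
            return heads
        c = min(bad, key=lambda c: abs(c - heads[c]))  # Shortest offender.
        heads[c] = heads[heads[c]]  # Lift: attach child to grandparent.

# Arc 3->1 crosses arc 0->2; one lift re-attaches token 1 to the root:
print(projectivize([-1, 3, 0, 0]))  # [-1, 0, 0, 0]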
......@@ -18,6 +18,6 @@ limitations under the License.
namespace syntaxnet {
// Component registry for document formatters.
REGISTER_CLASS_REGISTRY("document format", DocumentFormat);
REGISTER_SYNTAXNET_CLASS_REGISTRY("document format", DocumentFormat);
} // namespace syntaxnet
......@@ -32,7 +32,7 @@ namespace syntaxnet {
// A document format component converts a key/value pair from a record to one or
// more documents. The record format is used for selecting the document format
// component. A document format component can be registered with the
// REGISTER_DOCUMENT_FORMAT macro.
// REGISTER_SYNTAXNET_DOCUMENT_FORMAT macro.
class DocumentFormat : public RegisterableClass<DocumentFormat> {
public:
DocumentFormat() {}
......@@ -47,7 +47,7 @@ class DocumentFormat : public RegisterableClass<DocumentFormat> {
// Converts a key/value pair to one or more documents.
virtual void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *documents) = 0;
std::vector<Sentence *> *documents) = 0;
// Converts a document to a key/value pair.
virtual void ConvertToString(const Sentence &document,
......@@ -57,8 +57,8 @@ class DocumentFormat : public RegisterableClass<DocumentFormat> {
TF_DISALLOW_COPY_AND_ASSIGN(DocumentFormat);
};
#define REGISTER_DOCUMENT_FORMAT(type, component) \
REGISTER_CLASS_COMPONENT(DocumentFormat, type, component)
#define REGISTER_SYNTAXNET_DOCUMENT_FORMAT(type, component) \
REGISTER_SYNTAXNET_CLASS_COMPONENT(DocumentFormat, type, component)
} // namespace syntaxnet
......
......@@ -46,14 +46,16 @@ void GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
void GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
}
vector<vector<SparseFeatures>> GenericEmbeddingFeatureExtractor::ConvertExample(
const vector<FeatureVector> &feature_vectors) const {
std::vector<std::vector<SparseFeatures>>
GenericEmbeddingFeatureExtractor::ConvertExample(
const std::vector<FeatureVector> &feature_vectors) const {
// Extract the features.
vector<vector<SparseFeatures>> sparse_features(feature_vectors.size());
std::vector<std::vector<SparseFeatures>> sparse_features(
feature_vectors.size());
for (size_t i = 0; i < feature_vectors.size(); ++i) {
// Convert the nlp_parser::FeatureVector to dist belief format.
sparse_features[i] =
vector<SparseFeatures>(generic_feature_extractor(i).feature_types());
sparse_features[i] = std::vector<SparseFeatures>(
generic_feature_extractor(i).feature_types());
for (int j = 0; j < feature_vectors[i].size(); ++j) {
const FeatureType &feature_type = *feature_vectors[i].type(j);
......
......@@ -78,15 +78,20 @@ class GenericEmbeddingFeatureExtractor {
int EmbeddingDims(int index) const { return embedding_dims_[index]; }
// Accessor for embedding dims (dimensions of the embedding spaces).
const vector<int> &embedding_dims() const { return embedding_dims_; }
const std::vector<int> &embedding_dims() const { return embedding_dims_; }
const vector<string> &embedding_fml() const { return embedding_fml_; }
const std::vector<string> &embedding_fml() const { return embedding_fml_; }
// Get parameter name by concatenating the prefix and the original name.
string GetParamName(const string &param_name) const {
return tensorflow::strings::StrCat(ArgPrefix(), "_", param_name);
}
// Returns the name of the embedding space.
const string &embedding_name(int index) const {
return embedding_names_[index];
}
protected:
// Provides the generic class with access to the templated extractors. This is
// used to get the type information out of the feature extractor without
......@@ -99,21 +104,21 @@ class GenericEmbeddingFeatureExtractor {
// single SparseFeatures. The predicates are mapped through map_fn which
// should point to either mutable_map_fn or const_map_fn depending on whether
// or not the predicate maps should be updated.
vector<vector<SparseFeatures>> ConvertExample(
const vector<FeatureVector> &feature_vectors) const;
std::vector<std::vector<SparseFeatures>> ConvertExample(
const std::vector<FeatureVector> &feature_vectors) const;
private:
// Embedding space names for parameter sharing.
vector<string> embedding_names_;
std::vector<string> embedding_names_;
// FML strings for each feature extractor.
vector<string> embedding_fml_;
std::vector<string> embedding_fml_;
// Size of each of the embedding spaces (maximum predicate id).
vector<int> embedding_sizes_;
std::vector<int> embedding_sizes_;
// Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
vector<int> embedding_dims_;
std::vector<int> embedding_dims_;
// Whether or not to add string descriptions to converted examples.
bool add_strings_;
......@@ -168,9 +173,9 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
// will not be updated and so unrecognized predicates may occur. In such a
// case the SparseFeatures object associated with a given extractor class and
// feature will be empty.
vector<vector<SparseFeatures>> ExtractSparseFeatures(
std::vector<std::vector<SparseFeatures>> ExtractSparseFeatures(
const WorkspaceSet &workspaces, const OBJ &obj, ARGS... args) const {
vector<FeatureVector> features(feature_extractors_.size());
std::vector<FeatureVector> features(feature_extractors_.size());
ExtractFeatures(workspaces, obj, args..., &features);
return ConvertExample(features);
}
......@@ -180,7 +185,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
// mapping is applied.
void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
ARGS... args,
vector<FeatureVector> *features) const {
std::vector<FeatureVector> *features) const {
DCHECK(features != nullptr);
DCHECK_EQ(features->size(), feature_extractors_.size());
for (int i = 0; i < feature_extractors_.size(); ++i) {
......@@ -201,7 +206,7 @@ class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
private:
// Templated feature extractor class.
vector<EXTRACTOR> feature_extractors_;
std::vector<EXTRACTOR> feature_extractors_;
};
class ParserEmbeddingFeatureExtractor
......
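For orientation, here is a rough Python analogue of the two-step flow `ExtractSparseFeatures` wraps: run every extractor, then convert each `FeatureVector` channel into per-type sparse features. The extractor callables and plain dicts are hypothetical stand-ins for the templated `EXTRACTOR` classes and `SparseFeatures` protos.

```python
# Sketch of the ExtractFeatures -> ConvertExample pipeline.
def extract_sparse_features(extractors, obj):
    feature_vectors = [ex(obj) for ex in extractors]   # ExtractFeatures.
    return [convert_example(fv) for fv in feature_vectors]

def convert_example(feature_vector):
    # One bucket of (id, weight) pairs per feature type in this channel.
    sparse = {}
    for feature_type, value in feature_vector:
        sparse.setdefault(feature_type, []).append((value, 1.0))
    return sparse

# Two toy extractors over a "sentence": word ids and a length bucket.
extractors = [
    lambda s: [("word", w) for w in s["word_ids"]],
    lambda s: [("length", min(len(s["word_ids"]), 8))],
]
print(extract_sparse_features(extractors, {"word_ids": [3, 14, 15]}))
```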
......@@ -50,13 +50,13 @@ void GenericFeatureExtractor::InitializeFeatureTypes() {
}
}
vector<string> types_names;
std::vector<string> types_names;
GetFeatureTypeNames(&types_names);
CHECK_EQ(feature_types_.size(), types_names.size());
}
void GenericFeatureExtractor::GetFeatureTypeNames(
vector<string> *type_names) const {
std::vector<string> *type_names) const {
for (size_t i = 0; i < feature_types_.size(); ++i) {
FeatureType *ft = feature_types_[i];
type_names->push_back(ft->name());
......@@ -102,7 +102,7 @@ int GenericFeatureFunction::GetIntParameter(const string &name,
}
void GenericFeatureFunction::GetFeatureTypes(
vector<FeatureType *> *types) const {
std::vector<FeatureType *> *types) const {
if (feature_type_ != nullptr) types->push_back(feature_type_);
}
......@@ -111,7 +111,7 @@ FeatureType *GenericFeatureFunction::GetFeatureType() const {
if (feature_type_ != nullptr) return feature_type_;
// Get feature types for function.
vector<FeatureType *> types;
std::vector<FeatureType *> types;
GetFeatureTypes(&types);
// If there is exactly one feature type return this, else return null.
......
......@@ -101,7 +101,7 @@ class FeatureVector {
};
// Array for storing feature vector elements.
vector<Element> features_;
std::vector<Element> features_;
TF_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
};
......@@ -133,7 +133,7 @@ class GenericFeatureExtractor {
// Returns all feature types names used by the extractor. The names are
// added to the types_names array. Invalid before Init() has been called.
void GetFeatureTypeNames(vector<string> *type_names) const;
void GetFeatureTypeNames(std::vector<string> *type_names) const;
// Returns a feature type used in the extractor. Invalid before Init() has
// been called.
......@@ -157,7 +157,7 @@ class GenericFeatureExtractor {
// Returns all feature types used by the extractor. The feature types are
// added to the result array.
virtual void GetFeatureTypes(vector<FeatureType *> *types) const = 0;
virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
// Descriptor for the feature extractor. This is a protocol buffer that
// contains all the information about the feature extractor. The feature
......@@ -167,7 +167,7 @@ class GenericFeatureExtractor {
// All feature types used by the feature extractor. The collection of all the
// feature types describes the feature space of the feature set produced by
// the feature extractor. Not owned.
vector<FeatureType *> feature_types_;
std::vector<FeatureType *> feature_types_;
};
// The generic feature function is the type-independent part of a feature
......@@ -198,7 +198,7 @@ class GenericFeatureFunction {
// Appends the feature types produced by the feature function to types. The
// default implementation appends feature_type(), if non-null. Invalid
// before Init() has been called.
virtual void GetFeatureTypes(vector<FeatureType *> *types) const;
virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
// Returns the feature type for feature produced by this feature function. If
// the feature function produces features of different types this returns
......@@ -383,7 +383,7 @@ class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }
// By default, just appends the nested feature types.
void GetFeatureTypes(vector<FeatureType *> *types) const override {
void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
CHECK(!this->nested().empty())
<< "Nested features require nested features to be defined.";
for (auto *function : nested_) function->GetFeatureTypes(types);
......@@ -415,14 +415,14 @@ class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
}
// Returns the list of nested feature functions.
const vector<NES *> &nested() const { return nested_; }
const std::vector<NES *> &nested() const { return nested_; }
// Instantiates nested feature functions for a feature function. Creates and
// initializes one feature function for each sub-descriptor in the feature
// descriptor.
static void CreateNested(GenericFeatureExtractor *extractor,
FeatureFunctionDescriptor *fd,
vector<NES *> *functions,
std::vector<NES *> *functions,
const string &prefix) {
for (int i = 0; i < fd->feature_size(); ++i) {
FeatureFunctionDescriptor *sub = fd->mutable_feature(i);
......@@ -434,7 +434,7 @@ class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
protected:
// The nested feature functions, if any, in order of declaration in the
// feature descriptor. Owned.
vector<NES *> nested_;
std::vector<NES *> nested_;
};
// Base class for a nested feature function that takes nested features with the
......@@ -506,7 +506,7 @@ template<class DER, class OBJ, class ...ARGS>
class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
public:
// Feature locators have an additional check that there is no intrinsic type.
void GetFeatureTypes(vector<FeatureType *> *types) const override {
void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
CHECK(this->feature_type() == nullptr)
<< "FeatureLocators should not have an intrinsic type.";
MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
......@@ -604,7 +604,7 @@ class FeatureExtractor : public GenericFeatureExtractor {
}
// Collect all feature types used in the feature extractor.
void GetFeatureTypes(vector<FeatureType *> *types) const override {
void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
for (int i = 0; i < functions_.size(); ++i) {
functions_[i]->GetFeatureTypes(types);
}
......@@ -612,11 +612,11 @@ class FeatureExtractor : public GenericFeatureExtractor {
// Top-level feature functions (and variables) in the feature extractor.
// Owned.
vector<Function *> functions_;
std::vector<Function *> functions_;
};
#define REGISTER_FEATURE_FUNCTION(base, name, component) \
REGISTER_CLASS_COMPONENT(base, name, component)
#define REGISTER_SYNTAXNET_FEATURE_FUNCTION(base, name, component) \
REGISTER_SYNTAXNET_CLASS_COMPONENT(base, name, component)
} // namespace syntaxnet
......
......@@ -81,7 +81,7 @@ class ResourceBasedFeatureType : public FeatureType {
// resource->NumValues() so as to avoid collisions; this is verified with
// CHECK at creation.
ResourceBasedFeatureType(const string &name, const Resource *resource,
const map<FeatureValue, string> &values)
const std::map<FeatureValue, string> &values)
: FeatureType(name), resource_(resource), values_(values) {
max_value_ = resource->NumValues() - 1;
for (const auto &pair : values) {
......@@ -121,7 +121,7 @@ class ResourceBasedFeatureType : public FeatureType {
FeatureValue max_value_;
// Mapping for extra feature values not in the resource.
map<FeatureValue, string> values_;
std::map<FeatureValue, string> values_;
};
// Feature type that is defined using an explicit map from FeatureValue to
......@@ -139,7 +139,7 @@ class ResourceBasedFeatureType : public FeatureType {
class EnumFeatureType : public FeatureType {
public:
EnumFeatureType(const string &name,
const map<FeatureValue, string> &value_names)
const std::map<FeatureValue, string> &value_names)
: FeatureType(name), value_names_(value_names) {
for (const auto &pair : value_names) {
CHECK_GE(pair.first, 0)
......@@ -168,7 +168,26 @@ class EnumFeatureType : public FeatureType {
FeatureValue domain_size_ = 0;
// Names of feature values.
map<FeatureValue, string> value_names_;
std::map<FeatureValue, string> value_names_;
};
// Feature type for numeric features.
class NumericFeatureType : public FeatureType {
public:
// Initializes numeric feature.
NumericFeatureType(const string &name, FeatureValue size)
: FeatureType(name), size_(size) {}
// Returns numeric feature value.
string GetFeatureValueName(FeatureValue value) const override {
return value < 0 ? "" : tensorflow::strings::Printf("%lld", value);
}
// Returns the number of feature values.
FeatureValue GetDomainSize() const override { return size_; }
private:
FeatureValue size_;
};
} // namespace syntaxnet
......
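The contract these classes implement — map a feature value to a display name, report the domain size — can be mirrored in a few lines of Python. This is an illustration of `EnumFeatureType` and the newly added `NumericFeatureType`, not the SyntaxNet API.

```python
# Sketch of the feature-type contract.
class EnumFeatureType:
    """Names come from an explicit value -> string map."""
    def __init__(self, name, value_names):
        self.name = name
        self.value_names = dict(value_names)
        self.domain_size = max(value_names) + 1 if value_names else 0

    def feature_value_name(self, value):
        return self.value_names.get(value, "<INVALID>")

class NumericFeatureType:
    """Numeric features print their value; the domain size is fixed."""
    def __init__(self, name, size):
        self.name = name
        self.size = size

    def feature_value_name(self, value):
        return "" if value < 0 else str(value)

print(EnumFeatureType("pos", {0: "NOUN", 1: "VERB"}).feature_value_name(1))
print(NumericFeatureType("length", 64).feature_value_name(7))
```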
......@@ -69,7 +69,7 @@ def EmbeddingLookupFeatures(params, sparse_features, allow_weights):
if allow_weights:
# Multiply by weights, reshaping to allow broadcast.
broadcast_weights_shape = tf.concat(0, [tf.shape(weights), [1]])
broadcast_weights_shape = tf.concat_v2([tf.shape(weights), [1]], 0)
embeddings *= tf.reshape(weights, broadcast_weights_shape)
# Sum embeddings by index.
......@@ -330,7 +330,7 @@ class GreedyParser(object):
i,
return_average=return_average))
last_layer = tf.concat(1, embeddings)
last_layer = tf.concat_v2(embeddings, 1)
last_layer_size = self.embedding_size
# Create ReLU layers.
......@@ -404,8 +404,9 @@ class GreedyParser(object):
"""Cross entropy plus L2 loss on weights and biases of the hidden layers."""
dense_golden = BatchedSparseToDense(gold_actions, self._num_actions)
cross_entropy = tf.div(
tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(
logits, dense_golden)), batch_size)
tf.reduce_sum(
tf.nn.softmax_cross_entropy_with_logits(
labels=dense_golden, logits=logits)), batch_size)
regularized_params = [tf.nn.l2_loss(p)
for k, p in self.params.items()
if k.startswith('weights') or k.startswith('bias')]
......
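The Python hunks above are the TensorFlow r0.12 migration in miniature: `tf.concat(concat_dim, values)` becomes `tf.concat_v2(values, axis)` with the arguments swapped, and `softmax_cross_entropy_with_logits` is called with explicit keyword arguments. A standalone sketch of both changes, assuming a TF r0.12 environment (both calls changed again in TF 1.x):

```python
# Illustrative sketch of the r0.12 API changes applied in this commit.
import tensorflow as tf

a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0, 4.0]])

# Before r0.12: tf.concat(1, [a, b]). r0.12 adds tf.concat_v2 with the
# values list first and the axis second.
joined = tf.concat_v2([a, b], 1)

logits = tf.constant([[2.0, 0.5]])
labels = tf.constant([[1.0, 0.0]])

# The old positional form softmax_cross_entropy_with_logits(logits, labels)
# is deprecated; keyword arguments make the operand order unambiguous.
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
```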