"git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "7715bd705cf56c45a7c1f434235eb7279f4d089b"
Unverified Commit d4e1f97f authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #5352 from clarkkev/master

Added cvt_text model
parents 1f484095 e9b55413
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for task-specific modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
class SupervisedModule(object):
__metaclass__ = abc.ABCMeta
def __init__(self):
self.supervised_loss = NotImplemented
self.probs = NotImplemented
self.preds = NotImplemented
@abc.abstractmethod
def update_feed_dict(self, feed, mb):
pass
class SemiSupervisedModule(SupervisedModule):
__metaclass__ = abc.ABCMeta
def __init__(self):
super(SemiSupervisedModule, self).__init__()
self.unsupervised_loss = NotImplemented
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Preprocesses pretrained word embeddings, creates dev sets for tasks without a
provided one, and figures out the set of output classes for each task.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
from base import configure
from base import embeddings
from base import utils
from task_specific.word_level import word_level_data
def main(data_dir='./data'):
random.seed(0)
utils.log("BUILDING WORD VOCABULARY/EMBEDDINGS")
for pretrained in ['glove.6B.300d.txt']:
config = configure.Config(data_dir=data_dir,
for_preprocessing=True,
pretrained_embeddings=pretrained,
word_embedding_size=300)
embeddings.PretrainedEmbeddingLoader(config).build()
utils.log("CONSTRUCTING DEV SETS")
for task_name in ["chunk"]:
# chunking does not come with a provided dev split, so create one by
# selecting a random subset of the data
config = configure.Config(data_dir=data_dir,
for_preprocessing=True)
task_data_dir = os.path.join(config.raw_data_topdir, task_name) + '/'
train_sentences = word_level_data.TaggedDataLoader(
config, task_name, False).get_labeled_sentences("train")
random.shuffle(train_sentences)
write_sentences(task_data_dir + 'train_subset.txt', train_sentences[1500:])
write_sentences(task_data_dir + 'dev.txt', train_sentences[:1500])
utils.log("WRITING LABEL MAPPINGS")
for task_name in ["chunk"]:
for i, label_encoding in enumerate(["BIOES"]):
config = configure.Config(data_dir=data_dir,
for_preprocessing=True,
label_encoding=label_encoding)
token_level = task_name in ["ccg", "pos", "depparse"]
loader = word_level_data.TaggedDataLoader(config, task_name, token_level)
if token_level:
if i != 0:
continue
utils.log("WRITING LABEL MAPPING FOR", task_name.upper())
else:
utils.log(" Writing label mapping for", task_name.upper(),
label_encoding)
utils.log(" ", len(loader.label_mapping), "classes")
utils.write_cpickle(loader.label_mapping,
loader.label_mapping_path)
def write_sentences(fname, sentences):
with open(fname, 'w') as f:
for words, tags in sentences:
for word, tag in zip(words, tags):
f.write(word + " " + tag + "\n")
f.write("\n")
if __name__ == '__main__':
main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines all the tasks the model can learn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from base import embeddings
from task_specific.word_level import depparse_module
from task_specific.word_level import depparse_scorer
from task_specific.word_level import tagging_module
from task_specific.word_level import tagging_scorers
from task_specific.word_level import word_level_data
class Task(object):
__metaclass__ = abc.ABCMeta
def __init__(self, config, name, loader):
self.config = config
self.name = name
self.loader = loader
self.train_set = self.loader.get_dataset("train")
self.val_set = self.loader.get_dataset("dev" if config.dev_set else "test")
@abc.abstractmethod
def get_module(self, inputs, encoder):
pass
@abc.abstractmethod
def get_scorer(self):
pass
class Tagging(Task):
def __init__(self, config, name, is_token_level=True):
super(Tagging, self).__init__(
config, name, word_level_data.TaggedDataLoader(
config, name, is_token_level))
self.n_classes = len(set(self.loader.label_mapping.values()))
self.is_token_level = is_token_level
def get_module(self, inputs, encoder):
return tagging_module.TaggingModule(
self.config, self.name, self.n_classes, inputs, encoder)
def get_scorer(self):
if self.is_token_level:
return tagging_scorers.AccuracyScorer()
else:
return tagging_scorers.EntityLevelF1Scorer(self.loader.label_mapping)
class DependencyParsing(Tagging):
def __init__(self, config, name):
super(DependencyParsing, self).__init__(config, name, True)
def get_module(self, inputs, encoder):
return depparse_module.DepparseModule(
self.config, self.name, self.n_classes, inputs, encoder)
def get_scorer(self):
return depparse_scorer.DepparseScorer(
self.n_classes, (embeddings.get_punctuation_ids(self.config)))
def get_task(config, name):
if name in ["ccg", "pos"]:
return Tagging(config, name, True)
elif name in ["chunk", "ner", "er"]:
return Tagging(config, name, False)
elif name == "depparse":
return DependencyParsing(config, name)
else:
raise ValueError("Unknown task", name)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dependency parsing module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from corpus_processing import minibatching
from model import model_helpers
from model import task_module
class DepparseModule(task_module.SemiSupervisedModule):
def __init__(self, config, task_name, n_classes, inputs, encoder):
super(DepparseModule, self).__init__()
self.task_name = task_name
self.n_classes = n_classes
self.labels = labels = tf.placeholder(tf.float32, [None, None, None],
name=task_name + '_labels')
class PredictionModule(object):
def __init__(self, name, dep_reprs, head_reprs, roll_direction=0):
self.name = name
with tf.variable_scope(name + '/predictions'):
# apply hidden layers to the input representations
arc_dep_hidden = model_helpers.project(
dep_reprs, config.projection_size, 'arc_dep_hidden')
arc_head_hidden = model_helpers.project(
head_reprs, config.projection_size, 'arc_head_hidden')
arc_dep_hidden = tf.nn.relu(arc_dep_hidden)
arc_head_hidden = tf.nn.relu(arc_head_hidden)
arc_head_hidden = tf.nn.dropout(arc_head_hidden, inputs.keep_prob)
arc_dep_hidden = tf.nn.dropout(arc_dep_hidden, inputs.keep_prob)
# bilinear classifier excluding the final dot product
arc_head = tf.layers.dense(
arc_head_hidden, config.depparse_projection_size, name='arc_head')
W = tf.get_variable('shared_W',
shape=[config.projection_size, n_classes,
config.depparse_projection_size])
Wr = tf.get_variable('relation_specific_W',
shape=[config.projection_size,
config.depparse_projection_size])
Wr_proj = tf.tile(tf.expand_dims(Wr, axis=-2), [1, n_classes, 1])
W += Wr_proj
arc_dep = tf.tensordot(arc_dep_hidden, W, axes=[[-1], [0]])
shape = tf.shape(arc_dep)
arc_dep = tf.reshape(arc_dep,
[shape[0], -1, config.depparse_projection_size])
# apply the transformer scaling trick to prevent dot products from
# getting too large (possibly not necessary)
scale = np.power(
config.depparse_projection_size, 0.25).astype('float32')
scale = tf.get_variable('scale', initializer=scale, dtype=tf.float32)
arc_dep /= scale
arc_head /= scale
# compute the scores for each candidate arc
word_scores = tf.matmul(arc_head, arc_dep, transpose_b=True)
root_scores = tf.layers.dense(arc_head, n_classes, name='root_score')
arc_scores = tf.concat([root_scores, word_scores], axis=-1)
# disallow the model from making impossible predictions
mask = inputs.mask
mask_shape = tf.shape(mask)
mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, n_classes])
mask = tf.reshape(mask, [-1, mask_shape[1] * n_classes])
mask = tf.concat([tf.ones((mask_shape[0], 1)),
tf.zeros((mask_shape[0], n_classes - 1)), mask],
axis=1)
mask = tf.tile(tf.expand_dims(mask, 1), [1, mask_shape[1], 1])
arc_scores += (mask - 1) * 100.0
self.logits = arc_scores
self.loss = model_helpers.masked_ce_loss(
self.logits, labels, inputs.mask,
roll_direction=roll_direction)
primary = PredictionModule(
'primary',
[encoder.uni_reprs, encoder.bi_reprs],
[encoder.uni_reprs, encoder.bi_reprs])
ps = [
PredictionModule(
'full',
[encoder.uni_reprs, encoder.bi_reprs],
[encoder.uni_reprs, encoder.bi_reprs]),
PredictionModule('fw_fw', [encoder.uni_fw], [encoder.uni_fw]),
PredictionModule('fw_bw', [encoder.uni_fw], [encoder.uni_bw]),
PredictionModule('bw_fw', [encoder.uni_bw], [encoder.uni_fw]),
PredictionModule('bw_bw', [encoder.uni_bw], [encoder.uni_bw]),
]
self.unsupervised_loss = sum(p.loss for p in ps)
self.supervised_loss = primary.loss
self.probs = tf.nn.softmax(primary.logits)
self.preds = tf.argmax(primary.logits, axis=-1)
def update_feed_dict(self, feed, mb):
if self.task_name in mb.teacher_predictions:
feed[self.labels] = mb.teacher_predictions[self.task_name]
elif mb.task_name != 'unlabeled':
labels = minibatching.build_array(
[[0] + e.labels + [0] for e in mb.examples])
feed[self.labels] = np.eye(
(1 + mb.words.shape[1]) * self.n_classes)[labels]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dependency parsing evaluation (computes UAS/LAS)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from task_specific.word_level import word_level_scorer
class DepparseScorer(word_level_scorer.WordLevelScorer):
def __init__(self, n_relations, punctuation):
super(DepparseScorer, self).__init__()
self._n_relations = n_relations
self._punctuation = punctuation if punctuation else None
def _get_results(self):
correct_unlabeled, correct_labeled, count = 0, 0, 0
for example, preds in zip(self._examples, self._preds):
for w, y_true, y_pred in zip(example.words[1:-1], example.labels, preds):
if w in self._punctuation:
continue
count += 1
correct_labeled += (1 if y_pred == y_true else 0)
correct_unlabeled += (1 if int(y_pred // self._n_relations) ==
int(y_true // self._n_relations) else 0)
return [
("las", 100.0 * correct_labeled / count),
("uas", 100.0 * correct_unlabeled / count),
("loss", self.get_loss()),
]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence tagging module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from corpus_processing import minibatching
from model import model_helpers
from model import task_module
class TaggingModule(task_module.SemiSupervisedModule):
def __init__(self, config, task_name, n_classes, inputs,
encoder):
super(TaggingModule, self).__init__()
self.task_name = task_name
self.n_classes = n_classes
self.labels = labels = tf.placeholder(tf.float32, [None, None, None],
name=task_name + '_labels')
class PredictionModule(object):
def __init__(self, name, input_reprs, roll_direction=0, activate=True):
self.name = name
with tf.variable_scope(name + '/predictions'):
projected = model_helpers.project(input_reprs, config.projection_size)
if activate:
projected = tf.nn.relu(projected)
self.logits = tf.layers.dense(projected, n_classes, name='predict')
targets = labels
targets *= (1 - inputs.label_smoothing)
targets += inputs.label_smoothing / n_classes
self.loss = model_helpers.masked_ce_loss(
self.logits, targets, inputs.mask, roll_direction=roll_direction)
primary = PredictionModule('primary',
([encoder.uni_reprs, encoder.bi_reprs]))
ps = [
PredictionModule('full', ([encoder.uni_reprs, encoder.bi_reprs]),
activate=False),
PredictionModule('forwards', [encoder.uni_fw]),
PredictionModule('backwards', [encoder.uni_bw]),
PredictionModule('future', [encoder.uni_fw], roll_direction=1),
PredictionModule('past', [encoder.uni_bw], roll_direction=-1),
]
self.unsupervised_loss = sum(p.loss for p in ps)
self.supervised_loss = primary.loss
self.probs = tf.nn.softmax(primary.logits)
self.preds = tf.argmax(primary.logits, axis=-1)
def update_feed_dict(self, feed, mb):
if self.task_name in mb.teacher_predictions:
feed[self.labels] = mb.teacher_predictions[self.task_name]
elif mb.task_name != 'unlabeled':
labels = minibatching.build_array(
[[0] + e.labels + [0] for e in mb.examples])
feed[self.labels] = np.eye(self.n_classes)[labels]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence tagging evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from task_specific.word_level import tagging_utils
from task_specific.word_level import word_level_scorer
class AccuracyScorer(word_level_scorer.WordLevelScorer):
def __init__(self, auto_fail_label=None):
super(AccuracyScorer, self).__init__()
self._auto_fail_label = auto_fail_label
def _get_results(self):
correct, count = 0, 0
for example, preds in zip(self._examples, self._preds):
for y_true, y_pred in zip(example.labels, preds):
count += 1
correct += (1 if y_pred == y_true and y_true != self._auto_fail_label
else 0)
return [
("accuracy", 100.0 * correct / count),
("loss", self.get_loss())
]
class F1Scorer(word_level_scorer.WordLevelScorer):
__metaclass__ = abc.ABCMeta
def __init__(self):
super(F1Scorer, self).__init__()
self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
def _get_results(self):
if self._n_correct == 0:
p, r, f1 = 0, 0, 0
else:
p = 100.0 * self._n_correct / self._n_predicted
r = 100.0 * self._n_correct / self._n_gold
f1 = 2 * p * r / (p + r)
return [
("precision", p),
("recall", r),
("f1", f1),
("loss", self.get_loss()),
]
class EntityLevelF1Scorer(F1Scorer):
def __init__(self, label_mapping):
super(EntityLevelF1Scorer, self).__init__()
self._inv_label_mapping = {v: k for k, v in label_mapping.iteritems()}
def _get_results(self):
self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
for example, preds in zip(self._examples, self._preds):
sent_spans = set(tagging_utils.get_span_labels(
example.labels, self._inv_label_mapping))
span_preds = set(tagging_utils.get_span_labels(
preds, self._inv_label_mapping))
self._n_correct += len(sent_spans & span_preds)
self._n_gold += len(sent_spans)
self._n_predicted += len(span_preds)
return super(EntityLevelF1Scorer, self)._get_results()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for sequence tagging tasks for entity-level tasks (e.g., NER)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def get_span_labels(sentence_tags, inv_label_mapping=None):
"""Go from token-level labels to list of entities (start, end, class)."""
if inv_label_mapping:
sentence_tags = [inv_label_mapping[i] for i in sentence_tags]
span_labels = []
last = 'O'
start = -1
for i, tag in enumerate(sentence_tags):
pos, _ = (None, 'O') if tag == 'O' else tag.split('-')
if (pos == 'S' or pos == 'B' or tag == 'O') and last != 'O':
span_labels.append((start, i - 1, last.split('-')[-1]))
if pos == 'B' or pos == 'S' or last == 'O':
start = i
last = tag
if sentence_tags[-1] != 'O':
span_labels.append((start, len(sentence_tags) - 1,
sentence_tags[-1].split('-')[-1]))
return span_labels
def get_tags(span_labels, length, encoding):
"""Converts a list of entities to token-label labels based on the provided
encoding (e.g., BIOES).
"""
tags = ['O' for _ in range(length)]
for s, e, t in span_labels:
for i in range(s, e + 1):
tags[i] = 'I-' + t
if 'E' in encoding:
tags[e] = 'E-' + t
if 'B' in encoding:
tags[s] = 'B-' + t
if 'S' in encoding and s == e:
tags[s] = 'S-' + t
return tags
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for processing word-level datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import random
import tensorflow as tf
from base import embeddings
from base import utils
from corpus_processing import example
from corpus_processing import minibatching
from task_specific.word_level import tagging_utils
class TaggedDataLoader(object):
def __init__(self, config, name, is_token_level):
self._config = config
self._task_name = name
self._raw_data_path = os.path.join(config.raw_data_topdir, name)
self._is_token_level = is_token_level
self.label_mapping_path = os.path.join(
config.preprocessed_data_topdir,
(name if is_token_level else
name + '_' + config.label_encoding) + '_label_mapping.pkl')
if self.label_mapping:
self._n_classes = len(set(self.label_mapping.values()))
else:
self._n_classes = None
def get_dataset(self, split):
if (split == 'train' and not self._config.for_preprocessing and
tf.gfile.Exists(os.path.join(self._raw_data_path, 'train_subset.txt'))):
split = 'train_subset'
return minibatching.Dataset(
self._config, self._get_examples(split), self._task_name)
def get_labeled_sentences(self, split):
sentences = []
path = os.path.join(self._raw_data_path, split + '.txt')
if not tf.gfile.Exists(path):
if self._config.for_preprocessing:
return []
else:
raise ValueError('Unable to load data from', path)
with tf.gfile.GFile(path, 'r') as f:
sentence = []
for line in f:
line = line.strip().split()
if not line:
if sentence:
words, tags = zip(*sentence)
sentences.append((words, tags))
sentence = []
continue
if line[0] == '-DOCSTART-':
continue
word, tag = line[0], line[-1]
sentence.append((word, tag))
return sentences
@property
def label_mapping(self):
if not self._config.for_preprocessing:
return utils.load_cpickle(self.label_mapping_path)
tag_counts = collections.Counter()
train_tags = set()
for split in ['train', 'dev', 'test']:
for words, tags in self.get_labeled_sentences(split):
if not self._is_token_level:
span_labels = tagging_utils.get_span_labels(tags)
tags = tagging_utils.get_tags(
span_labels, len(words), self._config.label_encoding)
for tag in tags:
if self._task_name == 'depparse':
tag = tag.split('-')[1]
tag_counts[tag] += 1
if split == 'train':
train_tags.add(tag)
if self._task_name == 'ccg':
# for CCG, there are tags in the test sets that aren't in the train set
# all tags not in the train set get mapped to a special label
# the model will never predict this label because it never sees it in the
# training set
not_in_train_tags = []
for tag, count in tag_counts.items():
if tag not in train_tags:
not_in_train_tags.append(tag)
label_mapping = {
label: i for i, label in enumerate(sorted(filter(
lambda t: t not in not_in_train_tags, tag_counts.keys())))
}
n = len(label_mapping)
for tag in not_in_train_tags:
label_mapping[tag] = n
else:
labels = sorted(tag_counts.keys())
if self._task_name == 'depparse':
labels.remove('root')
labels.insert(0, 'root')
label_mapping = {label: i for i, label in enumerate(labels)}
return label_mapping
def _get_examples(self, split):
word_vocab = embeddings.get_word_vocab(self._config)
char_vocab = embeddings.get_char_vocab()
examples = [
TaggingExample(
self._config, self._is_token_level, words, tags,
word_vocab, char_vocab, self.label_mapping, self._task_name)
for words, tags in self.get_labeled_sentences(split)]
if self._config.train_set_percent < 100:
utils.log('using reduced train set ({:}%)'.format(
self._config.train_set_percent))
random.shuffle(examples)
examples = examples[:int(len(examples) *
self._config.train_set_percent / 100.0)]
return examples
class TaggingExample(example.Example):
def __init__(self, config, is_token_level, words, original_tags,
word_vocab, char_vocab, label_mapping, task_name):
super(TaggingExample, self).__init__(words, word_vocab, char_vocab)
if is_token_level:
labels = original_tags
else:
span_labels = tagging_utils.get_span_labels(original_tags)
labels = tagging_utils.get_tags(
span_labels, len(words), config.label_encoding)
if task_name == 'depparse':
self.labels = []
for l in labels:
split = l.split('-')
self.labels.append(
len(label_mapping) * (0 if split[0] == '0' else 1 + int(split[0]))
+ label_mapping[split[1]])
else:
self.labels = [label_mapping[l] for l in labels]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for word-level scorers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from corpus_processing import scorer
class WordLevelScorer(scorer.Scorer):
__metaclass__ = abc.ABCMeta
def __init__(self):
super(WordLevelScorer, self).__init__()
self._total_loss = 0
self._total_words = 0
self._examples = []
self._preds = []
def update(self, examples, predictions, loss):
super(WordLevelScorer, self).update(examples, predictions, loss)
n_words = 0
for example, preds in zip(examples, predictions):
self._examples.append(example)
self._preds.append(list(preds)[1:len(example.words) - 1])
n_words += len(example.words) - 2
self._total_loss += loss * n_words
self._total_words += n_words
def get_loss(self):
return self._total_loss / max(1, self._total_words)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs training for CVT text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import bisect
import time
import numpy as np
import tensorflow as tf
from base import utils
from model import multitask_model
from task_specific import task_definitions
class Trainer(object):
def __init__(self, config):
self._config = config
self.tasks = [task_definitions.get_task(self._config, task_name)
for task_name in self._config.task_names]
utils.log('Loading Pretrained Embeddings')
pretrained_embeddings = utils.load_cpickle(self._config.word_embeddings)
utils.log('Building Model')
self._model = multitask_model.Model(
self._config, pretrained_embeddings, self.tasks)
utils.log()
def train(self, sess, progress, summary_writer):
heading = lambda s: utils.heading(s, '(' + self._config.model_name + ')')
trained_on_sentences = 0
start_time = time.time()
unsupervised_loss_total, unsupervised_loss_count = 0, 0
supervised_loss_total, supervised_loss_count = 0, 0
for mb in self._get_training_mbs(progress.unlabeled_data_reader):
if mb.task_name != 'unlabeled':
loss = self._model.train_labeled(sess, mb)
supervised_loss_total += loss
supervised_loss_count += 1
if mb.task_name == 'unlabeled':
self._model.run_teacher(sess, mb)
loss = self._model.train_unlabeled(sess, mb)
unsupervised_loss_total += loss
unsupervised_loss_count += 1
mb.teacher_predictions.clear()
trained_on_sentences += mb.size
global_step = self._model.get_global_step(sess)
if global_step % self._config.print_every == 0:
utils.log('step {:} - '
'supervised loss: {:.2f} - '
'unsupervised loss: {:.2f} - '
'{:.1f} sentences per second'.format(
global_step,
supervised_loss_total / max(1, supervised_loss_count),
unsupervised_loss_total / max(1, unsupervised_loss_count),
trained_on_sentences / (time.time() - start_time)))
unsupervised_loss_total, unsupervised_loss_count = 0, 0
supervised_loss_total, supervised_loss_count = 0, 0
if global_step % self._config.eval_dev_every == 0:
heading('EVAL ON DEV')
self.evaluate_all_tasks(sess, summary_writer, progress.history)
progress.save_if_best_dev_model(sess, global_step)
utils.log()
if global_step % self._config.eval_train_every == 0:
heading('EVAL ON TRAIN')
self.evaluate_all_tasks(sess, summary_writer, progress.history, True)
utils.log()
if global_step % self._config.save_model_every == 0:
heading('CHECKPOINTING MODEL')
progress.write(sess, global_step)
utils.log()
def evaluate_all_tasks(self, sess, summary_writer, history, train_set=False):
for task in self.tasks:
results = self._evaluate_task(sess, task, summary_writer, train_set)
if history is not None:
results.append(('step', self._model.get_global_step(sess)))
history.append(results)
if history is not None:
utils.write_cpickle(history, self._config.history_file)
def _evaluate_task(self, sess, task, summary_writer, train_set):
scorer = task.get_scorer()
data = task.train_set if train_set else task.val_set
for i, mb in enumerate(data.get_minibatches(self._config.test_batch_size)):
loss, batch_preds = self._model.test(sess, mb)
scorer.update(mb.examples, batch_preds, loss)
results = scorer.get_results(task.name +
('_train_' if train_set else '_dev_'))
utils.log(task.name.upper() + ': ' + scorer.results_str())
write_summary(summary_writer, results,
global_step=self._model.get_global_step(sess))
return results
def _get_training_mbs(self, unlabeled_data_reader):
datasets = [task.train_set for task in self.tasks]
weights = [np.sqrt(dataset.size) for dataset in datasets]
thresholds = np.cumsum([w / np.sum(weights) for w in weights])
labeled_mbs = [dataset.endless_minibatches(self._config.train_batch_size)
for dataset in datasets]
unlabeled_mbs = unlabeled_data_reader.endless_minibatches()
while True:
dataset_ind = bisect.bisect(thresholds, np.random.random())
yield next(labeled_mbs[dataset_ind])
if self._config.is_semisup:
yield next(unlabeled_mbs)
def write_summary(writer, results, global_step):
for k, v in results:
if 'f1' in k or 'acc' in k or 'loss' in k:
writer.add_summary(tf.Summary(
value=[tf.Summary.Value(tag=k, simple_value=v)]), global_step)
writer.flush()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tracks and saves training progress (models and other data such as the current
location in the lm1b corpus) for later reloading.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from base import utils
from corpus_processing import unlabeled_data
class TrainingProgress(object):
def __init__(self, config, sess, checkpoint_saver, best_model_saver,
restore_if_possible=True):
self.config = config
self.checkpoint_saver = checkpoint_saver
self.best_model_saver = best_model_saver
tf.gfile.MakeDirs(config.checkpoints_dir)
if restore_if_possible and tf.gfile.Exists(config.progress):
history, current_file, current_line = utils.load_cpickle(
config.progress, memoized=False)
self.history = history
self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
config, current_file, current_line)
utils.log("Continuing from global step", dict(self.history[-1])["step"],
"(lm1b file {:}, line {:})".format(current_file, current_line))
self.checkpoint_saver.restore(sess, tf.train.latest_checkpoint(
self.config.checkpoints_dir))
else:
utils.log("No previous checkpoint found - starting from scratch")
self.history = []
self.unlabeled_data_reader = (
unlabeled_data.UnlabeledDataReader(config))
def write(self, sess, global_step):
self.checkpoint_saver.save(sess, self.config.checkpoint,
global_step=global_step)
utils.write_cpickle(
(self.history, self.unlabeled_data_reader.current_file,
self.unlabeled_data_reader.current_line),
self.config.progress)
def save_if_best_dev_model(self, sess, global_step):
best_avg_score = 0
for i, results in enumerate(self.history):
if any("train" in metric for metric, value in results):
continue
total, count = 0, 0
for metric, value in results:
if "f1" in metric or "las" in metric or "accuracy" in metric:
total += value
count += 1
avg_score = total / count
if avg_score >= best_avg_score:
best_avg_score = avg_score
if i == len(self.history) - 1:
utils.log("New best model! Saving...")
self.best_model_saver.save(sess, self.config.best_model_checkpoint,
global_step=global_step)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment