"vscode:/vscode.git/clone" did not exist on "66d883f2118cbaf925fcbfd130cbdc5d2387073d"
Commit 02242bc8 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 318387106
parent 7ebcbe20
...@@ -15,7 +15,12 @@ ...@@ -15,7 +15,12 @@
# ============================================================================== # ==============================================================================
"""Tagging (e.g., NER/POS) task.""" """Tagging (e.g., NER/POS) task."""
import logging import logging
from typing import List, Optional
import dataclasses import dataclasses
from seqeval import metrics as seqeval_metrics
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
...@@ -36,12 +41,12 @@ class TaggingConfig(cfg.TaskConfig): ...@@ -36,12 +41,12 @@ class TaggingConfig(cfg.TaskConfig):
model: encoders.TransformerEncoderConfig = ( model: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig()) encoders.TransformerEncoderConfig())
# The number of real labels. Note that a word may be tokenized into # The real class names, the order of which should match real label id.
# multiple word_pieces tokens, and we assume the real label id (non-negative) # Note that a word may be tokenized into multiple word_pieces tokens, and
# is assigned to the first token of the word, and a negative label id is # we assume the real label id (non-negative) is assigned to the first token
# assigned to the remaining tokens. The negative label id will not contribute # of the word, and a negative label id is assigned to the remaining tokens.
# to loss and metrics. # The negative label id will not contribute to loss and metrics.
num_classes: int = 0 class_names: Optional[List[str]] = None
train_data: cfg.DataConfig = cfg.DataConfig() train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
...@@ -75,8 +80,8 @@ class TaggingTask(base_task.Task): ...@@ -75,8 +80,8 @@ class TaggingTask(base_task.Task):
if params.hub_module_url and params.init_checkpoint: if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and ' raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.') '`init_checkpoint` can be specified.')
if params.num_classes == 0: if not params.class_names:
raise ValueError('TaggingConfig.num_classes cannot be 0.') raise ValueError('TaggingConfig.class_names cannot be empty.')
if params.hub_module_url: if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url) self._hub_module = hub.load(params.hub_module_url)
...@@ -92,7 +97,7 @@ class TaggingTask(base_task.Task): ...@@ -92,7 +97,7 @@ class TaggingTask(base_task.Task):
return models.BertTokenClassifier( return models.BertTokenClassifier(
network=encoder_network, network=encoder_network,
num_classes=self.task_config.num_classes, num_classes=len(self.task_config.class_names),
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.model.initializer_range), stddev=self.task_config.model.initializer_range),
dropout_rate=self.task_config.model.dropout_rate, dropout_rate=self.task_config.model.dropout_rate,
...@@ -123,7 +128,7 @@ class TaggingTask(base_task.Task): ...@@ -123,7 +128,7 @@ class TaggingTask(base_task.Task):
y = tf.random.uniform( y = tf.random.uniform(
shape=(1, params.seq_length), shape=(1, params.seq_length),
minval=-1, minval=-1,
maxval=self.task_config.num_classes, maxval=len(self.task_config.class_names),
dtype=tf.dtypes.int32) dtype=tf.dtypes.int32)
return (x, y) return (x, y)
...@@ -136,19 +141,66 @@ class TaggingTask(base_task.Task): ...@@ -136,19 +141,66 @@ class TaggingTask(base_task.Task):
dataset = tagging_data_loader.TaggingDataLoader(params).load(input_context) dataset = tagging_data_loader.TaggingDataLoader(params).load(input_context)
return dataset return dataset
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
  """Validation step.

  Runs inference on one batch and collects per-token prediction/label ids
  for sequence-level (seqeval) metric aggregation.

  Args:
    inputs: a dictionary of input tensors.
    model: the keras.Model.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs containing the loss and the flattened
    `predict_ids`/`label_ids` of real (non-padding) tokens.
  """
  features, labels = inputs
  outputs = self.inference_step(features, model)
  loss = self.build_losses(labels=labels, model_outputs=outputs)
  # Negative label ids are padding labels which should be ignored.
  real_label_index = tf.where(tf.greater_equal(labels, 0))
  predict_ids = tf.math.argmax(outputs, axis=-1)
  predict_ids = tf.gather_nd(predict_ids, real_label_index)
  label_ids = tf.gather_nd(labels, real_label_index)
  return {
      self.loss: loss,
      'predict_ids': predict_ids,
      'label_ids': label_ids,
  }
def aggregate_logs(self, state=None, step_outputs=None):
  """Aggregates over logs returned from a validation step."""
  if state is None:
    state = {'predict_class': [], 'label_class': []}

  def _to_class_names(batched_ids):
    # Map each batch of token ids to the corresponding class-name strings.
    names = self.task_config.class_names
    return [[names[token_id] for token_id in ids.numpy().tolist()]
            for ids in batched_ids]

  # `seqeval_metrics` relies on the class names to decide IOB tags, so ids
  # are converted to names before being accumulated.
  state['predict_class'].extend(_to_class_names(step_outputs['predict_ids']))
  state['label_class'].extend(_to_class_names(step_outputs['label_ids']))
  return state
def reduce_aggregated_logs(self, aggregated_logs):
  """Reduces aggregated logs over validation steps."""
  y_true = aggregated_logs['label_class']
  y_pred = aggregated_logs['predict_class']
  # All four scores come from seqeval, which works on class-name sequences.
  return {
      'f1': seqeval_metrics.f1_score(y_true, y_pred),
      'precision': seqeval_metrics.precision_score(y_true, y_pred),
      'recall': seqeval_metrics.recall_score(y_true, y_pred),
      'accuracy': seqeval_metrics.accuracy_score(y_true, y_pred),
  }
def initialize(self, model): def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0.""" """Load a pretrained checkpoint (if exists) and then train from iter 0."""
......
...@@ -58,7 +58,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -58,7 +58,7 @@ class TaggingTest(tf.test.TestCase):
init_checkpoint=saved_path, init_checkpoint=saved_path,
model=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config, train_data=self._train_data_config,
num_classes=3) class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(config) task = tagging.TaggingTask(config)
model = task.build_model() model = task.build_model()
metrics = task.build_metrics() metrics = task.build_metrics()
...@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase):
config = tagging.TaggingConfig( config = tagging.TaggingConfig(
model=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config, train_data=self._train_data_config,
num_classes=3) class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(config) task = tagging.TaggingTask(config)
model = task.build_model() model = task.build_model()
...@@ -116,10 +116,31 @@ class TaggingTest(tf.test.TestCase): ...@@ -116,10 +116,31 @@ class TaggingTest(tf.test.TestCase):
config = tagging.TaggingConfig( config = tagging.TaggingConfig(
hub_module_url=hub_module_url, hub_module_url=hub_module_url,
model=self._encoder_config, model=self._encoder_config,
num_classes=4, class_names=["O", "B-PER", "I-PER"],
train_data=self._train_data_config) train_data=self._train_data_config)
self._run_task(config) self._run_task(config)
def test_seqeval_metrics(self):
  """Checks that validation logs reduce to the seqeval metric keys."""
  config = tagging.TaggingConfig(
      model=self._encoder_config,
      train_data=self._train_data_config,
      class_names=["O", "B-PER", "I-PER"])
  task = tagging.TaggingTask(config)
  model = task.build_model()
  batch = next(iter(task.build_inputs(config.train_data)))
  strategy = tf.distribute.get_strategy()
  per_replica_logs = strategy.run(
      functools.partial(task.validation_step, model=model), args=(batch,))
  local_logs = tf.nest.map_structure(strategy.experimental_local_results,
                                     per_replica_logs)
  # Aggregate twice to exercise both the fresh-state and resumed-state paths.
  state = task.aggregate_logs(step_outputs=local_logs)
  state = task.aggregate_logs(state=state, step_outputs=local_logs)
  self.assertCountEqual({"f1", "precision", "recall", "accuracy"},
                        task.reduce_aggregated_logs(state).keys())
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -45,6 +45,9 @@ def _get_requirements(): ...@@ -45,6 +45,9 @@ def _get_requirements():
os.path.join(os.path.dirname(__file__), '../requirements.txt'), 'r') as f: os.path.join(os.path.dirname(__file__), '../requirements.txt'), 'r') as f:
for line in f: for line in f:
package_name = line.strip() package_name = line.strip()
# Skip empty line or comments starting with "#".
if not package_name or package_name[0] == '#':
continue
if package_name.startswith('-e '): if package_name.startswith('-e '):
dependency_links_tmp.append(package_name[3:].strip()) dependency_links_tmp.append(package_name[3:].strip())
else: else:
......
...@@ -16,10 +16,13 @@ dataclasses ...@@ -16,10 +16,13 @@ dataclasses
gin-config gin-config
tf_slim>=1.1.0 tf_slim>=1.1.0
typing typing
sentencepiece
Cython Cython
matplotlib matplotlib
opencv-python-headless
pyyaml pyyaml
# CV related dependencies
opencv-python-headless
Pillow Pillow
-e git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI -e git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI
# NLP related dependencies
seqeval
sentencepiece
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment