Commit b0693846 authored by Ken Franko's avatar Ken Franko Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 321057047
parent be8c2556
...@@ -262,7 +262,7 @@ def predict(task: TaggingTask, params: cfg.DataConfig, ...@@ -262,7 +262,7 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
label_mask=label_mask, label_mask=label_mask,
sentence_ids=sentence_ids) sentence_ids=sentence_ids)
outputs = tf.distribute.get_strategy().experimental_run_v2( outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),)) _replicated_step, args=(next(iterator),))
return tf.nest.map_structure( return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs) tf.distribute.get_strategy().experimental_local_results, outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment