Commit b1188d03 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Fix transformer_main.py.

An earlier change made it so running this file would cause an error. The unit tests still passed, as the unit tests do not directly call the main() function

PiperOrigin-RevId: 264268400
parent 67420e17
...@@ -403,8 +403,7 @@ def main(_): ...@@ -403,8 +403,7 @@ def main(_):
if not flags_obj.distribution_strategy != "tpu": if not flags_obj.distribution_strategy != "tpu":
_run_task(task) _run_task(task)
else: else:
primary_cpu_task = ("/job:worker" primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else ""
if flags_obj.use_tpu_2vm_config is not None else "")
with tf.device(primary_cpu_task): with tf.device(primary_cpu_task):
_run_task(task) _run_task(task)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment