"tests/vscode:/vscode.git/clone" did not exist on "b63c956860373ef169cfb24c2088a9b173a72bfd"
Commit 0f5bdd0e authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 285285388
parent 143d09bc
...@@ -365,7 +365,12 @@ class TransformerTask(object): ...@@ -365,7 +365,12 @@ class TransformerTask(object):
def eval(self): def eval(self):
"""Evaluates the model.""" """Evaluates the model."""
with distribution_utils.get_strategy_scope(self.distribution_strategy): distribution_strategy = self.distribution_strategy if self.use_tpu else None
# We only want to create the model under DS scope for TPU case.
# When 'distribution_strategy' is None, a no-op DummyContextManager will
# be used.
with distribution_utils.get_strategy_scope(distribution_strategy):
if not self.predict_model: if not self.predict_model:
self.predict_model = transformer.create_model(self.params, False) self.predict_model = transformer.create_model(self.params, False)
self._load_weights_if_possible( self._load_weights_if_possible(
...@@ -375,7 +380,7 @@ class TransformerTask(object): ...@@ -375,7 +380,7 @@ class TransformerTask(object):
return evaluate_and_log_bleu( return evaluate_and_log_bleu(
self.predict_model, self.params, self.flags_obj.bleu_source, self.predict_model, self.params, self.flags_obj.bleu_source,
self.flags_obj.bleu_ref, self.flags_obj.vocab_file, self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
self.distribution_strategy if self.use_tpu else None) distribution_strategy)
def predict(self): def predict(self):
"""Predicts result from the model.""" """Predicts result from the model."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment