Commit 2d7a0d6a authored by Jiang Yu, committed by Taylor Robie

fix batch_size in transformer_main.py (#4897)

* fix batch_size in transformer_main.py

fix batch_size in transformer_main.py, which caused a ResourceExhaustedError (OOM) when training Transformer models with models/official/transformer on GPU: the batch size fell back to the TPU default even when no TPU was in use

* small format change

split the expression across multiple lines so it passes the lint checks

* remove trailing whitespace and add a comment
parent c1588f00
@@ -555,9 +555,12 @@ def run_transformer(flags_obj):
   params["use_synthetic_data"] = flags_obj.use_synthetic_data
 
-  # Set batch size parameter, which depends on TPU and distribution settings.
-  params["batch_size"] = (
-      flags_obj.batch_size or params["default_batch_size_tpu"])
+  # Set batch size parameter, which depends on the availability of
+  # TPU and GPU, and distribution settings.
+  params["batch_size"] = (flags_obj.batch_size or (
+      params["default_batch_size_tpu"] if params["use_tpu"]
+      else params["default_batch_size"]))
 
   if not params["use_tpu"]:
     params["batch_size"] = distribution_utils.per_device_batch_size(
         params["batch_size"], num_gpus)
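Before this fix, the fallback always used params["default_batch_size_tpu"], so a GPU run without an explicit --batch_size inherited the much larger TPU default and exhausted device memory. The following is a minimal standalone sketch of the corrected selection logic; the default values are illustrative stand-ins, and the final division is a stand-in for distribution_utils.per_device_batch_size (the real code reads its defaults from official/transformer/model/model_params.py).

# Sketch of the corrected batch-size selection. The defaults below are
# hypothetical; real values come from the Transformer model params.
def select_batch_size(flag_batch_size, use_tpu, num_gpus):
    params = {
        "default_batch_size": 2048,       # hypothetical GPU default
        "default_batch_size_tpu": 32768,  # hypothetical TPU default
        "use_tpu": use_tpu,
    }
    # Fall back to the default that matches the available hardware,
    # instead of unconditionally using the (much larger) TPU default.
    params["batch_size"] = (flag_batch_size or (
        params["default_batch_size_tpu"] if params["use_tpu"]
        else params["default_batch_size"]))
    if not params["use_tpu"]:
        # Stand-in for distribution_utils.per_device_batch_size:
        # split the global batch size across the available GPUs.
        params["batch_size"] //= max(num_gpus, 1)
    return params["batch_size"]

# Without the fix, a 2-GPU run with no --batch_size flag would have
# started from the TPU default instead of the GPU default.
print(select_batch_size(None, use_tpu=False, num_gpus=2))  # -> 1024

With the fix, an unset --batch_size resolves to the hardware-appropriate default first, and the per-GPU division is then applied only on the non-TPU path, matching the patched code above.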