Commit 2d7a0d6a authored by Jiang Yu, committed by Taylor Robie

fix batch_size in transformer_main.py (#4897)

* Fix batch_size in transformer_main.py

Fix the batch_size fallback in transformer_main.py, which caused ResourceExhaustedError: OOM when training Transformer models with models/official/transformer: when --batch_size was not set, the code fell back to default_batch_size_tpu even on GPU runs.

* Small formatting change

Split the expression across multiple lines to pass the lint checks.

* Remove trailing whitespace and add an explanatory comment
parent c1588f00
@@ -555,9 +555,12 @@ def run_transformer(flags_obj):
   params["use_synthetic_data"] = flags_obj.use_synthetic_data
 
-  # Set batch size parameter, which depends on TPU and distribution settings.
-  params["batch_size"] = (
-      flags_obj.batch_size or params["default_batch_size_tpu"])
+  # Set batch size parameter, which depends on the availability of
+  # TPU and GPU, and distribution settings.
+  params["batch_size"] = (flags_obj.batch_size or (
+      params["default_batch_size_tpu"] if params["use_tpu"]
+      else params["default_batch_size"]))
 
   if not params["use_tpu"]:
     params["batch_size"] = distribution_utils.per_device_batch_size(
         params["batch_size"], num_gpus)
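
For context, here is a minimal sketch of what the distribution_utils.per_device_batch_size call on the GPU path is expected to do. The function name and call site come from the diff above; the body below is an assumption modeled on the helper in models/official/utils/misc/distribution_utils.py, not code from this commit:

# NOTE: hypothetical sketch of the GPU-path helper; the real implementation
# lives in models/official/utils/misc/distribution_utils.py.
def per_device_batch_size(batch_size, num_gpus):
  """Splits a global batch size evenly across num_gpus replicas."""
  if num_gpus <= 1:
    # Single device (or CPU): the global batch size is used as-is.
    return batch_size
  if batch_size % num_gpus:
    # Assumed behavior: refuse uneven splits so every replica sees the
    # same per-device batch size.
    raise ValueError(
        "Global batch size {} must be divisible by the number of GPUs {}."
        .format(batch_size, num_gpus))
  return batch_size // num_gpus

# Illustrative values only: 8 GPUs sharing a global batch of 4096 examples
# gives each replica 512 examples per step.
assert per_device_batch_size(4096, 8) == 512

With the fix in place, a GPU run that omits --batch_size now starts from default_batch_size rather than the much larger default_batch_size_tpu, and this helper then divides that global size across the available GPUs instead of letting a TPU-sized batch exhaust GPU memory.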