Commit 649e9774 authored by VictorSanh

Fix bug: train_batch_size not an int.

Division makes args.train_batch_size become a float.
cc @thomwolf
parent d55c3ae8
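For context (a minimal sketch, not part of the commit itself): in Python 3 the `/` operator performs true division and returns a float even when both operands are ints, so dividing the batch size by the gradient-accumulation factor needs an explicit cast (or floor division) before the result is used as a batch size. The variable names below mirror the diff but the snippet is illustrative only:

```python
# Python 3's `/` is true division: it returns a float
# even when both operands are ints.
train_batch_size = 32
accumulate_gradients = 4

bad = train_batch_size / accumulate_gradients
print(bad, type(bad))              # 8.0 <class 'float'>

# The commit's fix: cast the quotient back to int.
good = int(train_batch_size / accumulate_gradients)
print(good, type(good))            # 8 <class 'int'>

# Floor division would be an equivalent alternative here.
also_good = train_batch_size // accumulate_gradients
print(also_good, type(also_good))  # 8 <class 'int'>
```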
@@ -426,7 +426,7 @@ def main():
         raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
             args.accumulate_gradients))
-    args.train_batch_size = args.train_batch_size / args.accumulate_gradients
+    args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
     random.seed(args.seed)
     np.random.seed(args.seed)
...
@@ -756,7 +756,7 @@ def main():
         raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
             args.accumulate_gradients))
-    args.train_batch_size = args.train_batch_size / args.accumulate_gradients
+    args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
     random.seed(args.seed)
     np.random.seed(args.seed)
...