Commit a1af5247 authored by VictorSanh's avatar VictorSanh
Browse files

Add seed in initialization

parent 4faeb38b
...@@ -427,7 +427,10 @@ def main(): ...@@ -427,7 +427,10 @@ def main():
type=int, type=int,
default=-1, default=-1,
help="local_rank for distributed training on gpus") help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
args = parser.parse_args() args = parser.parse_args()
processors = { processors = {
...@@ -444,7 +447,12 @@ def main(): ...@@ -444,7 +447,12 @@ def main():
n_gpu = 1 n_gpu = 1
# print("Initializing the distributed backend: NCCL") # print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu) print("device", device, "n_gpu", n_gpu)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval: if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.") raise ValueError("At least one of `do_train` or `do_eval` must be True.")
......
...@@ -745,6 +745,10 @@ def main(): ...@@ -745,6 +745,10 @@ def main():
type=int, type=int,
default=-1, default=-1,
help="local_rank for distributed training on gpus") help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
args = parser.parse_args() args = parser.parse_args()
...@@ -757,6 +761,11 @@ def main(): ...@@ -757,6 +761,11 @@ def main():
# print("Initializing the distributed backend: NCCL") # print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu) print("device", device, "n_gpu", n_gpu)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_predict: if not args.do_train and not args.do_predict:
raise ValueError("At least one of `do_train` or `do_predict` must be True.") raise ValueError("At least one of `do_train` or `do_predict` must be True.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment