Unverified Commit a29ea36d authored by junming huang's avatar junming huang Committed by GitHub
Browse files

Update train_unconditional.py (#3899)

increase the time of timeout when using big dataset or high resolution
parent af48bf20
...@@ -4,6 +4,7 @@ import logging ...@@ -4,6 +4,7 @@ import logging
import math import math
import os import os
import shutil import shutil
from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -11,7 +12,7 @@ import accelerate ...@@ -11,7 +12,7 @@ import accelerate
import datasets import datasets
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from accelerate import Accelerator from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration from accelerate.utils import ProjectConfiguration
from datasets import load_dataset from datasets import load_dataset
...@@ -286,11 +287,13 @@ def main(args): ...@@ -286,11 +287,13 @@ def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))#a big number for high resolution or big dataset
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
log_with=args.logger, log_with=args.logger,
project_config=accelerator_project_config, project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
) )
if args.logger == "tensorboard": if args.logger == "tensorboard":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment