Unverified Commit a29ea36d, authored by junming huang, committed by GitHub
Browse files

Update train_unconditional.py (#3899)

Increase the process-group timeout (to 7200 seconds) so that training with a big dataset or high resolution does not time out.
parent af48bf20
......@@ -4,6 +4,7 @@ import logging
import math
import os
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Optional
......@@ -11,7 +12,7 @@ import accelerate
import datasets
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
......@@ -286,11 +287,13 @@ def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))#a big number for high resolution or big dataset
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.logger,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
)
if args.logger == "tensorboard":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment