from argparse import ArgumentParser def get_args(): parser = ArgumentParser() # 模型 parser.add_argument("--model_root", type=str, help="inpainting模型路径") parser.add_argument("--vae_subfolder", type=str, default="vae") # 数据 & 加载设置 parser.add_argument("--train_data_record_path", type=str, help="数据集信息路径") parser.add_argument("--eval_data_record_path", type=str, help="数据集信息路径") parser.add_argument("--height", type=int, default=512) parser.add_argument("--width", type=int, default=384) parser.add_argument("--max_grad_norm", default=1.0) # 训练相关 parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--gradient_accumulation_steps", type=int, default=4) parser.add_argument("--weight_dtype", type=str, default="bf16") parser.add_argument("--max_steps", type=int, default=60000) parser.add_argument("--noise_offset", type=float, default=None) parser.add_argument("--use_ema", action="store_true") parser.add_argument("--ema_decay", type=float, default=0.999) parser.add_argument("--extra_condition_key", type=str, default="empty") ## 优化器参数 parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--beta1", type=float, default=0.9) parser.add_argument("--beta2", type=float, default=0.999) parser.add_argument("--weight_decay", type=float, default=0.01) parser.add_argument("--eps", type=float, default=1e-08) ## 保存设置 parser.add_argument("--logging_steps", type=int, default=5) parser.add_argument("--output_dir", type=str, default="../checkpoints") parser.add_argument("--checkpoint_dir", type=str) parser.add_argument("--eval_output_dir", type=str) parser.add_argument("--global_steps", type=int, default=0) args = parser.parse_args() return args