from argparse import ArgumentParser


def get_args():
    parser = ArgumentParser()

    # 模型
    parser.add_argument("--model_root", type=str, help="inpainting模型路径")

    parser.add_argument("--vae_subfolder", type=str, default="vae")

    # 数据 & 加载设置
    parser.add_argument("--train_data_record_path", type=str, help="数据集信息路径")

    parser.add_argument("--eval_data_record_path", type=str, help="数据集信息路径")

    parser.add_argument("--height", type=int, default=512)

    parser.add_argument("--width", type=int, default=384)

    parser.add_argument("--max_grad_norm", default=1.0)
    
    # 训练相关
    parser.add_argument("--batch_size", type=int, default=8)

    parser.add_argument("--num_workers", type=int, default=8)

    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)

    parser.add_argument("--weight_dtype", type=str, default="bf16")

    parser.add_argument("--max_steps", type=int, default=60000)

    parser.add_argument("--noise_offset", type=float, default=None)

    parser.add_argument("--use_ema", action="store_true")

    parser.add_argument("--ema_decay", type=float, default=0.999)

    parser.add_argument("--extra_condition_key", type=str, default="empty")
    
    ## 优化器参数
    parser.add_argument("--lr", type=float, default=1e-5)

    parser.add_argument("--beta1", type=float, default=0.9)

    parser.add_argument("--beta2", type=float, default=0.999)

    parser.add_argument("--weight_decay", type=float, default=0.01)

    parser.add_argument("--eps", type=float, default=1e-08)
    
    ## 保存设置
    parser.add_argument("--logging_steps", type=int, default=5)

    parser.add_argument("--output_dir", type=str, default="../checkpoints")

    parser.add_argument("--checkpoint_dir", type=str)

    parser.add_argument("--eval_output_dir", type=str)

    parser.add_argument("--global_steps", type=int, default=0)

    args = parser.parse_args()

    return args

