# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Written to the DeepSpeed config as "train_batch_size" by default.
GLOBAL_BATCH_SIZE: int = 8
# Written to the DeepSpeed config as "train_micro_batch_size_per_gpu" by default.
MICRO_BATCH_SIZE: int = 1


def get_train_ds_config(offload, stage=2, precision="fp16", *,
                        train_batch_size=None, micro_batch_size=None):
    """Build a DeepSpeed training configuration dictionary.

    Args:
        offload: If truthy, both parameter and optimizer-state offload
            devices are set to ``"cpu"``; otherwise both are ``"none"``.
        stage: ZeRO optimization stage, written to
            ``zero_optimization.stage``.
        precision: Mixed-precision mode, either ``"fp16"`` or ``"bf16"``.
            A top-level section with that name is added to the config.
        train_batch_size: Value for ``train_batch_size``. When ``None``
            (the default), the module-level ``GLOBAL_BATCH_SIZE`` is used.
        micro_batch_size: Value for ``train_micro_batch_size_per_gpu``.
            When ``None`` (the default), the module-level
            ``MICRO_BATCH_SIZE`` is used.

    Returns:
        A newly built dict, including all nested dicts, intended to be
        passed to DeepSpeed as its config. The result is not shared
        between calls, so callers may mutate it freely.

    Raises:
        ValueError: If ``precision`` is neither ``"fp16"`` nor ``"bf16"``.
    """
    # Validate precision first so bad input fails before any other work.
    if precision == "fp16":
        precision_cfg = {
            "enabled": True,
            # A loss_scale of 0 selects DeepSpeed's dynamic loss scaling.
            # The remaining keys tune how that dynamic scale evolves.
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 12,
        }
    elif precision == "bf16":
        precision_cfg = {"enabled": True}
    else:
        raise ValueError(
            f"Unsupported mixed precision type {precision!r}; "
            "expected 'fp16' or 'bf16'"
        )

    # Fall back to the module-level defaults, resolved at call time.
    if train_batch_size is None:
        train_batch_size = GLOBAL_BATCH_SIZE
    if micro_batch_size is None:
        micro_batch_size = MICRO_BATCH_SIZE

    device = "cpu" if offload else "none"
    zero_opt_dict = {
        "stage": stage,
        "offload_param": {"device": device},
        "offload_optimizer": {"device": device},
        # The stage3_* knobs are emitted regardless of `stage`.
        # DeepSpeed documents them as ZeRO stage-3 settings.
        "stage3_param_persistence_threshold": 1e4,
        "stage3_max_live_parameters": 3e7,
        "stage3_prefetch_bucket_size": 3e7,
    }
    ds_config = {
        "train_batch_size": train_batch_size,
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "steps_per_print": 1,
        "zero_optimization": zero_opt_dict,
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "checkpoint": {"use_node_local_storage": True},
    }
    # The section key is the precision name itself ("fp16" or "bf16").
    ds_config[precision] = precision_cfg
    return ds_config
