import pandas as pd
import argparse
from conf import config

# Format templates for base Docker image tags; the placeholders are filled in
# by generate() via str.format.
# Example NVIDIA tag:  nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
# Example PyTorch tag: pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
BASE_NVIDIA_IMAGE_TAG = "nvidia/cuda:{cuda_version}-cudnn{cudnn_version}-{tag}-{op_system}"
BASE_TORCH_IMAGE_TAG = "pytorch/pytorch:{torch_version}-cuda{cuda_version}-cudnn{cudnn_version}-{tag}"


def generate():
    """Print the base Docker image tag derived from the input CSV.

    Reads the CSV at ``args.input_csv`` (``args`` is the module-level
    namespace assigned by the script entry point) and formats an image tag
    from its first row, using the NVIDIA template when
    ``args.base_image_from`` is ``"nvidia"`` and the PyTorch template
    otherwise. The resulting tag is written to stdout; nothing is returned.
    """
    table = pd.read_csv(args.input_csv)

    for _, record in table.iterrows():
        os_name = record["操作系统"]
        # Drop the "cuda" marker from the runtime column, leaving the version.
        cuda_ver = record["Runtime版本"].replace("cuda", "")
        # The cuDNN version is looked up by the CUDA major version.
        cuda_major = cuda_ver.split(".")[0]
        cudnn_ver = config.CUDNN_CONFIG[cuda_major]
        framework_ver = record["框架版本"]
        python_ver = record["Python版本"]  # read from the row but not used yet
        flavor = "devel" if args.devel_image else "runtime"

        # Fields common to both tag templates.
        shared_fields = {
            "cuda_version": cuda_ver,
            "cudnn_version": cudnn_ver,
            "tag": flavor,
        }
        if args.base_image_from == "nvidia":
            image_tag = BASE_NVIDIA_IMAGE_TAG.format(op_system=os_name, **shared_fields)
        else:
            image_tag = BASE_TORCH_IMAGE_TAG.format(torch_version=framework_ver, **shared_fields)
        print(image_tag)
        # NOTE(review): this unconditional break means only the first CSV row
        # is ever processed - confirm whether that is intentional.
        break


if __name__ == '__main__':
    # Script entry point: parse the command-line options, then print the tag.
    cli = argparse.ArgumentParser(description='Generate docker build args.')
    cli.add_argument(
        '--input-csv',
        default="AI内容协作表_GPU基础镜像(聂释隆).csv",
        help='input csv file path',
    )
    cli.add_argument(
        '--base-image-from',
        default="nvidia",
        choices=["nvidia", "torch"],
        help='choice base image from nvidia or torch',
    )
    cli.add_argument(
        '--devel-image',
        action='store_true',
        help='build devel image',
    )

    # generate() reads its options from this module-level name.
    args = cli.parse_args()
    generate()
