#!/usr/bin/env python3

import argparse
import datetime
import os
import shlex
import subprocess
import sys
import tempfile

def check_required():
    """Tell whether the pruning arguments must be supplied.

    Returns:
        bool: False when ``--version`` or ``-V`` appears verbatim in
        ``sys.argv``; True otherwise.
    """
    version_flags = ("--version", "-V")
    return not any(flag in sys.argv for flag in version_flags)

# Build the command-line interface. --arch, --output and the input file are
# only mandatory when check_required() says so, i.e. when neither --version
# nor -V was given on the command line.
_args_required = check_required()

parser = argparse.ArgumentParser(
    description='nvprune prunes host object files and libraries to only contain device code for the specified targets.',
)

# Register the supported command-line arguments.
parser.add_argument(
    '--version', '-V',
    action='store_true',
    help='Display version information',
)
parser.add_argument(
    '--arch', '-arch',
    metavar='<architecture>',
    required=_args_required,
    help='Set the architecture',
)
parser.add_argument(
    '-o', '--output',
    metavar='<outfile>',
    required=_args_required,
    help='Specify the output file',
)
parser.add_argument(
    'infile',
    metavar='<infile>',
    type=str,
    nargs=1 if _args_required else "?",
    help='Specify the input file',
)


# Passed as ``shell=`` to subprocess.run by exec_cmd.
run_on_shell = True
# When True, commands and intermediate values are printed while running.
DEBUG = False

def output_cmd(cmd):
    """Print ``cmd`` prefixed with ``cmd: `` when DEBUG is on; else do nothing."""
    if not DEBUG:
        return
    print("cmd: {}".format(cmd))

def check_ROCM_PATH():
    """Return the value of the ``ROCM_PATH`` environment variable.

    Returns:
        str: The value of ``ROCM_PATH`` exactly as found in the environment.

    Raises:
        SystemExit: With the message ``Error: ROCM_PATH is not set`` when the
            variable is absent from the environment.
    """
    try:
        return os.environ['ROCM_PATH']
    except KeyError:
        # Use sys.exit instead of the exit() helper: the latter is injected by
        # the site module for interactive use and is missing under
        # ``python -S`` or in frozen builds.
        sys.exit('Error: ROCM_PATH is not set')

def exec_cmd(cmd):
    """Run ``cmd`` and return everything it wrote to standard output.

    The command is echoed through output_cmd first, then executed with
    ``shell=run_on_shell``; both output streams are captured as text.

    Args:
        cmd: The command to run (callers in this file pass one shell string).

    Returns:
        str: The captured standard output of the command.

    Raises:
        SystemExit: With status 1 when the command returns a non-zero exit
            code; the command and its captured stderr are reported first.
    """
    output_cmd(cmd)
    status = subprocess.run(
        cmd,
        shell=run_on_shell,
        check=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        encoding='utf-8',
        # Tool output is only inspected or echoed; undecodable bytes should
        # not abort the run with UnicodeDecodeError.
        errors='replace',
    )
    if status.returncode != 0:
        # Diagnostics belong on stderr so they never mix with regular output.
        print("nvprune error: {}\n{}".format(cmd, status.stderr), file=sys.stderr)
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist outside interactive sessions.
        sys.exit(1)
    return status.stdout

def check_run_and_tripe(input):
    """Debug helper: execute ``input`` and print its output and its triples.

    Runs the file itself, prints what it wrote to stdout, then prints the
    output of ``llvm-amdgpu-objdump --show-target-triple`` for the same file.
    Both commands go through exec_cmd.

    Args:
        input: Path of the file to run and inspect. The name shadows the
            builtin but is kept so existing callers keep working.
    """
    # Resolve to an absolute path instead of prefixing "./": the prefix turns
    # an absolute path such as /tmp/a.out into the wrong ".//tmp/a.out".
    program = shlex.quote(os.path.abspath(input))
    output = exec_cmd(program)
    print("run: {}".format(output))
    # Quote the path so spaces or shell metacharacters stay part of the name.
    cmd = "llvm-amdgpu-objdump --show-target-triple --inputs={}".format(shlex.quote(input))
    output = exec_cmd(cmd)
    print("triple: {}".format(output))

def prune_arch(input, arch, output):
    """Copy ``input`` to ``output`` keeping only ``arch`` in ``.hip_fatbin``.

    All work happens on a private copy inside a temporary directory, which is
    removed on every exit path (normal return, early ``sys.exit`` and failed
    external commands). The current working directory is never changed.

    Steps:
        1. List the bundle targets with ``llvm-amdgpu-objdump``.
        2. Dump the ``.hip_fatbin`` section with ``objcopy``.
        3. Unbundle every listed target with ``clang-offload-bundler``.
        4. Re-bundle only ``arch`` together with an empty host entry.
        5. Write the new bundle back into the section and copy the result to
           ``output``.

    Args:
        input: Path of the file to prune. The name shadows the builtin but is
            kept so existing callers keep working.
        arch: Target to keep; it must equal one of the lines printed by the
            objdump step (compared after stripping surrounding whitespace).
        output: Destination path; a relative path is resolved against the
            current working directory.

    Raises:
        SystemExit: Status 1 when ``input`` does not exist (exec_cmd
            terminates the same way when an external command fails);
            status 0 when the file has no ``amdgcn-amd-amdhsa`` target or
            ``arch`` is not among its targets.
    """
    # The input must exist before anything is created on disk.
    if not os.path.exists(input):
        print("Error: Not exist Input file {}".format(input), file=sys.stderr)
        sys.exit(1)

    input_full = os.path.abspath(input)
    input_name = os.path.basename(input_full)

    if DEBUG:
        print("input_full: {}".format(input_full))
        print("input_name: {}".format(input_name))

    # A uniquely named temporary directory replaces the former
    # timestamp-named directory in the current directory: no name collisions,
    # no need for a writable cwd, and guaranteed cleanup.
    with tempfile.TemporaryDirectory(prefix="nvprune-") as work_dir:
        if DEBUG:
            print("work_dir:{}".format(work_dir))

        work_copy = os.path.join(work_dir, input_name)
        fatbin_ori = work_copy + ".hipfb.ori"
        fatbin_new = work_copy + ".hipfb"

        # Every path and target is passed through shlex.quote because the
        # commands are executed as shell strings.
        q_copy = shlex.quote(work_copy)

        # Copy the input into the work directory.
        exec_cmd("cp {} {}".format(shlex.quote(input_full), q_copy))

        # Query the targets contained in the file.
        status = exec_cmd("llvm-amdgpu-objdump --show-target-triple --inputs={}".format(q_copy))
        if status.find("amdgcn-amd-amdhsa") == -1:
            print("{} is not a amdgpu object file".format(input))
            # NOTE(review): exits with status 0 although no output file was
            # written; kept as-is, confirm this is intended.
            sys.exit(0)
        if DEBUG:
            print("triple:\n{}".format(status))
        triple_list = [line.strip() for line in status.splitlines() if line.strip()]
        if arch not in triple_list:
            print("There isn't triple: {} in {}. \nCurrent file has trpile:\n{}".format(arch, input, status))
            # NOTE(review): same status-0 exit as above; confirm intended.
            sys.exit(0)

        # Dump the .hip_fatbin section.
        exec_cmd("objcopy {} --dump-section=.hip_fatbin={}".format(q_copy, shlex.quote(fatbin_ori)))

        # Unbundle every target into a file named after the target.
        unbundle_outputs = "".join(
            " --output={}".format(shlex.quote(os.path.join(work_dir, triple)))
            for triple in triple_list
        )
        exec_cmd(
            "clang-offload-bundler --unbundle --input={} --type=o --targets={}{}".format(
                shlex.quote(fatbin_ori),
                shlex.quote(",".join(triple_list)),
                unbundle_outputs,
            )
        )

        # Re-bundle with the requested target only; the host entry is fed
        # from /dev/null.
        exec_cmd(
            "clang-offload-bundler -type=o -bundle-align=4096 -targets={} -input=/dev/null -input={} -output={}".format(
                shlex.quote("host-x86_64-unknown-linux," + arch),
                shlex.quote(os.path.join(work_dir, arch)),
                shlex.quote(fatbin_new),
            )
        )

        # Replace the .hip_fatbin section of the working copy.
        exec_cmd("objcopy {} --update-section .hip_fatbin={}".format(q_copy, shlex.quote(fatbin_new)))

        # Deliver the pruned copy; the work directory is deleted when the
        # ``with`` block ends.
        exec_cmd("cp {} {}".format(q_copy, shlex.quote(output)))

    if DEBUG:
        check_run_and_tripe(output)


def main():
    """Entry point: parse the command line and prune the input file.

    With ``--version``/``-V`` only the version string is printed. Otherwise
    exactly one input file is required, ``ROCM_PATH`` must be set (checked by
    check_ROCM_PATH), and the file is handed to prune_arch together with the
    requested architecture and output path.

    Raises:
        SystemExit: With an error message when no single input file was
            given. argparse and the helpers called here may also terminate
            the process.
    """
    # Parse the command-line arguments.
    args = parser.parse_args()

    if args.version:
        print('nvprune: version 1.0')
        return

    if not args.infile or len(args.infile) != 1:
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist outside interactive sessions.
        sys.exit("Erro: please set one input file!")
    infile = args.infile[0]

    if args.arch:
        print(f'Setting architecture to: {args.arch}')

    if DEBUG:
        print("infile: {}".format(infile))
        print("arch: {}".format(args.arch))
        print("output: {}".format(args.output))

    check_ROCM_PATH()

    prune_arch(infile, args.arch, args.output)


# Example invocations:
#   ./nvprune --version
#   ./nvprune MatrixTranspose -arch hipv4-amdgcn-amd-amdhsa--gfx906 -o gfx906.out

# Run the tool only when executed as a script, so that importing this file
# (for reuse or inspection) does not start a pruning run.
if __name__ == "__main__":
    main()

