#!/usr/bin/env python3

import argparse
import datetime
import os
import shlex
import shutil
import subprocess
import tempfile

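'''
Usage examples:

List the code objects embedded in the binary:
./cuobjdump MatrixTranspose -lelf
./cuobjdump MatrixTranspose -lptx
./cuobjdump MatrixTranspose -all

Dump the contents of the embedded code objects:
./cuobjdump MatrixTranspose -elf
./cuobjdump MatrixTranspose -ptx
./cuobjdump MatrixTranspose -sass
./cuobjdump MatrixTranspose -symbols

Extract one code object, named by its full target triple, into the current directory:
./cuobjdump MatrixTranspose -xelf hipv4-amdgcn-amd-amdhsa--gfx926
./cuobjdump MatrixTranspose -xptx hipv4-amdgcn-amd-amdhsa--gfx906
'''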


parser = argparse.ArgumentParser(description='cuobjdump extracts information from CUDA binary files (both standalone and those embedded in host binaries) and presents them in human readable format.')

group = parser.add_mutually_exclusive_group()
group.add_argument('--version', '-V', action='store_true', help='Print version information on this tool.')
group.add_argument('--list-elf', '-lelf', action='store_true', help='List all the ELF files available in the fatbin.')
group.add_argument('--list-ptx', '-lptx', action='store_true', help='List all the PTX files available in the fatbin.')
# group.add_argument('--list-text', '-ltext', metavar='<partial_file_name>', help='List all the text binary function names available in the fatbin.')
group.add_argument('--all-fatbin', '-all', action='store_true', help='Dump all fatbin sections.')
group.add_argument('--dump-elf', '-elf', action='store_true', help='Dump ELF Object sections.')
group.add_argument('--dump-elf-symbols', '-symbols', action='store_true', help='Dump ELF symbol names.')
group.add_argument('--dump-ptx', '-ptx', action='store_true', help='Dump ptx for all listed device functions.')
# group.add_argument('--dump-resource-usage', '-res-usage', action='store_true', help='Dump resource usage for each ELF.')
group.add_argument('--dump-sass', '-sass', action='store_true', help='Dump CUDA assembly for a single cubin file or all cubin files embedded in the binary.')
group.add_argument('--extract-elf', '-xelf', metavar='extract-elf', type=str, help='Extract ELF file(s) containing <partial_file_name>.')
group.add_argument('--extract-ptx', '-xptx', metavar='extract-ptx', type=str, help='Extract PTX file(s) containing <partial_file_name>.')
# parser.add_argument('--extract-text', '-xtext', metavar='<partial_file_name>', nargs='*', help='Extract text binary encoding file(s) containing <partial_file_name>.')
# parser.add_argument('--function', '-fun', metavar='<function_name>', nargs='*', help='Specify names of device functions whose fat binary structures must be dumped.')
# parser.add_argument('--function-index', '-findex', metavar='<function_index>', nargs='*', help='Specify symbol table index of the function whose fat binary structures must be dumped.')
parser.add_argument('--gpu-architecture', '-arch', metavar='<gpu_architecture>', choices=['gfx906', 'gfx926', 'sm_35', 'sm_37', 'sm_50', 'sm_52', 'sm_53', 'sm_60', 'sm_61', 'sm_62', 'sm_70', 'sm_72', 'sm_75', 'sm_80', 'sm_86', 'sm_87', 'sm_89', 'sm_90'], help='Specify GPU Architecture.')
parser.add_argument('infile', metavar='<infile>', type=str, help='Specify the input file')
# parser.add_argument('--help', '-h', action='store_true', help='Print help information on this tool.')
# group.add_argument('--options-file', '-optf', metavar='<file>', nargs='*', help='Include command line options from specified file.')
# parser.add_argument('--sort-functions', '-sort', action='store_true', help='Sort functions when dumping sass.')

run_on_shell=True
DEBUG = False

def print_title(input):
    print("\nFatbin elf code:", )
    print("=" * 30)
    print("arch = {}".format(input))
    print("code version = [5,2]")
    print("host = linux")
    print("compile_size = 64bit")

def output_cmd(cmd):
    if DEBUG:
        print("cmd: {}".format(cmd))

def check_ROCM_PATH():
    if 'ROCM_PATH' in os.environ:
        return os.environ['ROCM_PATH']
    else:
        exit('Error: ROCM_PATH is not set')

def exec_cmd(cmd):
    output_cmd(cmd)
    status = subprocess.run(cmd, shell=run_on_shell, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8')
    if status.returncode != 0:
        print("cuobjdump error: {}\n{}".format(cmd, status.stderr))
        exit(1)
    return status.stdout

import re
def parse_arch(arch):
    if arch.startswith('sm_'):
        arch = "gfx906"
    if arch.find("gfx") == -1:
        print("Error: arch {} is not supported".format(arch))
        exit(1)
    m = re.match(r'gfx(\d+)', arch)
    if m:
        return "hipv4-amdgcn-amd-amdhsa--gfx" + m.group(1)
    

def dump_elf(input):
    cmd = "readelf -a --wide {}\n".format(input)
    output = exec_cmd(cmd)
    print_title(input)
    print(output)

def dump_ptx(input):
    dump_sass(input)

def dump_symbols(input):
    cmd = "readelf -s --wide {}".format(input)
    output = exec_cmd(cmd)
    print_title(input)
    print(output)

def dump_sass(input):
    cmd = "llvm-objdump -d {}".format(input)
    output = exec_cmd(cmd)
    print_title(input)
    print("\ncode for {}".format(input))
    num = 0
    lines = output.split("\n")
    for line in lines:
        if not line.strip() or line.startswith(input) or line.startswith("Disassembly of section"):
            continue
        print(line)
        num += 1
        if num > 25:
            break
    if len(lines) > 25:
        print(" "* 10, "." * 10)

def extract_elf_internal(para, original_dir):
    cmd = "cp {} {}".format(para, original_dir)
    exec_cmd(cmd)

def extract_elf(para, triple_list, original_dir, bool_elf=True):
    if para == "all":
        for triple in triple_list:
            extract_elf_internal(triple, original_dir)
    else:
        if para not in triple_list:
            print("There isn't {}: {}".format("elf" if bool_elf else "ptx",para))
            exit(0)
        extract_elf_internal(para, original_dir)
    
def extract_ptx(para, triple_list, original_dir):
    extract_elf(para, triple_list, original_dir)


def main():
    # 解析命令行参数
    args = parser.parse_args()

    if args.version:
        print('cuobjdump: version 1.0')
        return
    
    check_ROCM_PATH()

    input = args.infile
    if not os.path.exists(input):
        print("Error: Not exist Input file {}".format(input))
        exit(1)

    # get full path of filename
    input_full = os.path.abspath(input)
    input_name = os.path.basename(input_full)
    
    if DEBUG:
        print("input_full:{}".format(input_full))
        print("input_name:{}".format(input_name))

    # create dir
    original_dir = os.getcwd()

    ct = datetime.datetime.now()
    ts = ct.timestamp()
    os.mkdir("{}".format(ts))
    os.chdir("{}".format(ts))
    work_dir = os.getcwd()
    if DEBUG:
        print("work_dir:{}".format(work_dir))

    # copy input to work dir
    cmd = "cp {} {}".format(input_full, input_name)
    output_cmd(cmd)
    status = subprocess.run(cmd, shell=run_on_shell, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8')
    if status.returncode != 0:
        print("Error: cp {}".format(status.stderr))
        exit(1)

    # get triple
    cmd = "llvm-amdgpu-objdump --show-target-triple --inputs={}".format(input_name)
    output_cmd(cmd)
    status = subprocess.run(cmd, shell=run_on_shell, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8')
    if status.returncode != 0:
        print("Error: llvm-amdgpu-objdump {}".format(status.stderr))
        exit(1)
    if status.stdout.find("amdgcn-amd-amdhsa") == -1:
        print("{} is not a amdgpu object file".format(input))
        exit(0)
    triple_list = [triple for triple in status.stdout.split("\n") if triple]
    # if arch not in triple_list:
    #     print("There isn't triple: {} in {}. \nCurrent file is:{}".format(arch, input, status.stdout))
    #     exit(0)

    # dump hip_fatbin
    cmd = "objcopy {0} --dump-section=.hip_fatbin={0}.hipfb.ori".format(input_name)
    output_cmd(cmd)
    exec_cmd(cmd)
    
    # unbundle hip_fatbin
    cmd = "clang-offload-bundler --unbundle --input={}.hipfb.ori --type=o --targets=".format(input_name)
    cmd += ",".join(triple_list)
    for triple in triple_list:
        cmd += " --output={}".format(triple)
    exec_cmd(cmd)

    arch = ""
    if args.gpu_architecture:
        arch = parse_arch(args.gpu_architecture)
        print("gpu_architecture: {}".format(arch))


    if arch and arch not in triple_list:
        print("There isn't arch: {}".format(args.gpu_architecture))
        exit(0)

    if args.list_elf or args.list_ptx:
        for triple in triple_list:
            print("ELF file", " "* 4, triple)

    if args.all_fatbin:
        for target in triple_list:
            if target:
                print_title(target)
        

    if args.dump_elf_symbols:
        if arch:
            dump_symbols(arch)
        else:
            for triple in triple_list:
                dump_symbols(triple)

    
    if args.dump_elf:
        if arch:
            dump_elf(arch)
        else:
            for triple in triple_list:
                dump_elf(triple)

    if args.dump_ptx:
        for target in triple_list:
            dump_ptx(target)

    if args.dump_sass:
        for target in triple_list:
            dump_sass(target)


    if args.extract_elf:
        extract_elf(args.extract_elf, triple_list, original_dir)

    
    if args.extract_ptx:
        extract_ptx(args.extract_ptx, triple_list, original_dir)



    # # cp hip_fatbin to output
    os.chdir(original_dir)

    cmd = "rm -rf {}".format(work_dir)
    exec_cmd(cmd)


main()

