import os
import re
import subprocess
from pathlib import Path

import torch

# Absolute path of the directory containing this file. Helper commands run
# with it as their working directory, and version.txt is looked up inside it.
ROOT_DIR = Path(__file__).parent.resolve()

def _run_cmd(cmd, shell=False):
    """Run *cmd* in ROOT_DIR and return its stripped stdout, or None on failure.

    This is deliberately best-effort. A missing executable, a non-zero exit
    status, or an OS-level error yields ``None`` so callers can fall back.
    The command's stderr is discarded.

    :param cmd: Command passed to ``subprocess.check_output``.
    :param shell: Forwarded to ``subprocess.check_output``.
    :return: Decoded, whitespace-stripped stdout, or ``None``.
    """
    try:
        output = subprocess.check_output(
            cmd, cwd=ROOT_DIR, stderr=subprocess.DEVNULL, shell=shell
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        return None
    # Decode as UTF-8, tolerating bad bytes. The previous ASCII decode raised
    # on any non-ASCII output (e.g. a branch name or path), and the broad
    # except then silently turned valid output into None.
    return output.decode("utf-8", errors="replace").strip()

def _get_version(version_file=None, default="2.5.1"):
    """Return the base package version string.

    Resolution order (later steps win):
      1. *default*;
      2. the ``__version__ = '...'`` assignment found in *version_file*, if
         the file exists and contains one;
      3. the ``BUILD_VERSION`` environment variable, if set and non-empty.

    :param version_file: File to scan; defaults to ``ROOT_DIR / "version.txt"``.
    :param default: Version used when neither the file nor the environment
        provides one.
    :return: The resolved version string.
    """
    path = ROOT_DIR / "version.txt" if version_file is None else Path(version_file)
    version = default
    if path.exists():
        content = path.read_text(encoding="utf-8").strip()
        # Raw string: "\s" in a plain literal is an invalid escape sequence
        # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
        match = re.search(r"__version__\s*=\s*['\"]([^'\"]+)['\"]", content)
        if match:
            version = match.group(1)
    build_version = os.getenv("BUILD_VERSION")
    if build_version:
        version = build_version
    return version

def _update_hcu_version(
    version, das_version, sha, abi, dtk, torch_version, branch, file_path=None
):
    """Set ``__hcu_version__`` in version.txt to *das_version*.

    Every ``__hcu_version__ = '...'`` assignment in the file is rewritten in
    place, and all other content is left untouched. If the file has no such
    assignment, one is appended, so the value is never silently dropped.

    Only *das_version* contributes to the written value. The other positional
    parameters (version, sha, abi, dtk, torch_version, branch) are accepted so
    the signature matches ``_make_version_file``; they are currently unused.

    :param das_version: New value, e.g. ``'2.5.1+das.opt1.dtk25042'``.
    :param file_path: File to modify; defaults to ``ROOT_DIR / "version.txt"``.
    :return: The value written, as a string.
    """
    hcu_version = f"{das_version}"
    path = ROOT_DIR / "version.txt" if file_path is None else Path(file_path)

    content = path.read_text(encoding="utf-8")

    # Match the __hcu_version__ assignment and replace the quoted value.
    pattern = r"(__hcu_version__\s*=\s*)['\"]([^'\"]*)['\"]"
    # A function replacement inserts the value literally. Interpolating it
    # into a replacement template would let backslashes / "\g<...>" in the
    # value be interpreted by re.sub.
    updated_content, n_subs = re.subn(
        pattern, lambda m: f"{m.group(1)}'{hcu_version}'", content
    )
    if n_subs == 0:
        if updated_content and not updated_content.endswith("\n"):
            updated_content += "\n"
        updated_content += f"__hcu_version__ = '{hcu_version}'\n"

    if updated_content != content:
        path.write_text(updated_content, encoding="utf-8")

    return hcu_version

def get_origin_version(version_file=None):
    """Execute version.txt as Python and return the ``__version__`` it defines.

    :param version_file: File to execute; defaults to ``ROOT_DIR / "version.txt"``.
    :return: The value bound to ``__version__`` by the file.
    :raises KeyError: If the file does not define ``__version__``.
    """
    # NOTE: the file is executed as code; it is expected to be a trusted,
    # repository-controlled file.
    path = ROOT_DIR / "version.txt" if version_file is None else Path(version_file)
    with open(path, encoding="utf-8") as f:
        source = f.read()
    # Execute into an explicit namespace. Reading the result back through
    # locals() after a bare exec() breaks on Python 3.13+ (PEP 667): each
    # locals() call in a function returns a fresh snapshot, so names bound by
    # exec() are not visible and the lookup raises KeyError.
    namespace = {}
    exec(compile(source, str(path), "exec"), namespace)
    return namespace["__version__"]

def _make_version_file(version, das_version, sha, abi, dtk, torch_version, branch):
    """Append a ``__hcu_version__ = '<das_version>'`` line to version.txt.

    Only *das_version* contributes to the written value. The remaining
    parameters are accepted but unused; an earlier variant (removed as
    commented-out code) embedded sha/abi/dtk/torch_version in the value and
    wrote them to the file as well.

    Each call appends a new line; an existing assignment is not replaced.

    :return: The value written, i.e. *das_version* formatted as a string.
    """
    hcu_version = format(das_version)
    target = ROOT_DIR / "version.txt"
    line = "__hcu_version__ = '" + hcu_version + "'\n"
    with target.open("a", encoding="utf-8") as handle:
        handle.write(line)
    return hcu_version

def _get_pytorch_version():
    """Return the PyTorch version string to report for this build.

    If the ``PYTORCH_VERSION`` environment variable is present (even when
    empty), its value is returned. Otherwise the imported module's
    ``torch.__version__`` is returned.
    """
    override = os.environ.get("PYTORCH_VERSION")
    if override is None:
        return torch.__version__
    return override

def get_version():
    """Compute the package version and, on ROCm/DTK builds, record it.

    If ``ROCM_PATH`` is unset or empty, this returns ``__version__`` read
    from version.txt via ``get_origin_version()``.

    Otherwise it builds ``<base>+das.opt1.dtk<dtk>``, where:
      * ``<base>`` comes from ``_get_version()``;
      * ``<dtk>`` is the contents of ``$ROCM_PATH/.info/rocm_version`` with
        the dots removed.

    It then writes that value to ``__hcu_version__`` in version.txt and
    returns it. Git metadata, the PyTorch version and the C++11 ABI flag are
    gathered and printed for the build log.

    :raises RuntimeError: If ``ROCM_PATH`` is set but the ROCm version file
        cannot be read.
    """
    rocm_home = os.getenv("ROCM_PATH")
    print(f"ROCM_HOME = {rocm_home}")
    if not rocm_home:
        return get_origin_version()

    sha = _run_cmd(["git", "rev-parse", "HEAD"])
    # _run_cmd returns None outside a git checkout or without git installed.
    sha = sha[:7] if sha else None
    branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
    print("-- Git branch:", branch)
    print("-- Git SHA:", sha)
    print("-- Git tag:", tag)

    torch_version = _get_pytorch_version()
    print("-- PyTorch:", torch_version)

    version = _get_version()
    print("-- Building version", version)
    das_version = version + "+das.opt1"
    print("-- Building das_version", das_version)

    # Ask the compiler which libstdc++ ABI it defaults to. `grep -F` replaces
    # the obsolescent `fgrep` alias.
    abi = _run_cmd(
        ["echo '#include <string>' | gcc -x c++ -E -dM - | grep -F _GLIBCXX_USE_CXX11_ABI | awk '{print $3}'"],
        shell=True,
    )
    print("-- _GLIBCXX_USE_CXX11_ABI:", abi)

    # Resolve relative to ROOT_DIR, matching the working directory the old
    # `cat` subprocess used (an absolute ROCM_PATH is unaffected).
    rocm_version_file = ROOT_DIR / rocm_home / ".info" / "rocm_version"
    try:
        dtk_raw = rocm_version_file.read_text(encoding="utf-8").strip()
    except OSError as err:
        raise RuntimeError(
            f"Cannot read DTK/ROCm version from {rocm_version_file}"
        ) from err
    dtk = dtk_raw.replace(".", "")
    print("-- DTK:", dtk)
    das_version += ".dtk" + dtk

    return _update_hcu_version(version, das_version, sha, abi, dtk, torch_version, branch)
