#!/usr/bin/env python3
"""
universal_ratio_full.py
一键批量对比成对 Excel（含首 token 延迟）；按文件名中的 “<前缀>_tpN” 部分配对
python universal_ratio_full.py folder1 folder2 [folder3 ...]
输出：summary_ratio.xlsx（可用环境变量 OUTPUT 指定其他文件名；每张 Sheet 含 5 列百分比）
"""

import argparse
import os
import re
import unicodedata

import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Side, Font, PatternFill
from openpyxl.utils import get_column_letter
DEFAULT_FONT = Font(name='Arial', size=10)  # base font for comparison-sheet cells
# ------------------ Styles ------------------
thin = Side('thin')
border = Border(top=thin, bottom=thin, left=thin, right=thin)  # thin line on all four sides
center = Alignment(horizontal='center', vertical='center', wrap_text=True)

header_font = Font(bold=True, color='FFFFFF')  # bold white; very low contrast on the light fills below
black_font = Font(name='微软雅黑', size=11, bold=True, color='000000')  # bold black, Microsoft YaHei
blue_fill   = PatternFill('solid', fgColor='e0ffff')   # light cyan (CSS "lightcyan"), not a strong blue
orange_fill = PatternFill('solid', fgColor='faf0e6')   # linen, a near-white beige (CSS "linen"), not orange
green_fill  = PatternFill('solid', fgColor='4de680')   # medium green
red_fill    = PatternFill('solid', fgColor='FF7f50')   # coral (CSS "coral"), an orange-red

# ================= 工具函数 =================
def build_pairs(folders):
    """Collect ``.xlsx`` files that share a ``<prefix>_tpN`` key exactly twice.

    The key of a file is its name up to and including the last
    ``_tp<digits>`` (e.g. ``model_a_tp4_run1.xlsx`` -> ``model_a_tp4``).
    Files whose name has no such part are ignored.

    Args:
        folders: directory paths, scanned (non-recursively) in the given order.

    Returns:
        dict mapping each key to a list of exactly two paths, ordered by folder
        order and then by file name. Keys matched by a single file, or by
        three or more files, are omitted.
    """
    files_map = {}
    for folder in folders:
        # os.listdir order is arbitrary; sorting makes the A/B order of a pair
        # found inside one folder deterministic.
        for name in sorted(os.listdir(folder)):
            # '~$...' files are temporary owner/lock files Excel leaves next to
            # an open workbook; they are not real result files.
            if not name.endswith('.xlsx') or name.startswith('~$'):
                continue
            match = re.search(r'(.+_tp\d+)', name)
            if not match:
                continue
            files_map.setdefault(match.group(1), []).append(os.path.join(folder, name))
    return {key: paths for key, paths in files_map.items() if len(paths) == 2}

def _ratio(num, den):
    """Return ``num / den * 100`` rounded to 2 decimals, or None when undefined.

    None is returned instead of raising or producing NaN when either operand
    cannot be converted to ``float`` (e.g. a text cell), either operand is NaN
    (e.g. an empty cell as read by pandas), or *den* is 0.
    """
    try:
        num, den = float(num), float(den)
    except (TypeError, ValueError):
        return None
    if pd.isna(num) or pd.isna(den) or den == 0:
        return None
    return round(num / den * 100, 2)


def _blank_if_nan(value):
    """Return None for pandas missing values (NaN/NaT/None), else *value* unchanged."""
    return None if pd.isna(value) else value


def _sheet_title(prefix):
    """Return *prefix* usable as an Excel sheet title.

    Characters not allowed in sheet titles (``\\ / * ? : [ ]``) become '_' and
    the result is cut to Excel's 31-character limit.
    """
    return re.sub(r'[\\/*?:\[\]]', '_', prefix)[:31] or 'Sheet'


def _display_width(value):
    """Approximate display width of *value*: East Asian wide/fullwidth chars count 2."""
    if value is None:
        return 0
    return sum(2 if unicodedata.east_asian_width(ch) in ('W', 'F') else 1
               for ch in str(value))


def _write_source_table(ws, rows, first_col, n_cols, caption):
    """Copy *rows* into *ws* starting at row 1, column *first_col*.

    Missing values are written as empty cells. Row 1 of the copied block is
    then merged across its *n_cols* columns and replaced by *caption* with
    ``orange_fill``. Returns the caption cell.
    """
    for r_idx, row in enumerate(rows, 1):
        for c_off, val in enumerate(row):
            ws.cell(row=r_idx, column=first_col + c_off, value=_blank_if_nan(val))
    if n_cols > 1:
        ws.merge_cells(start_row=1, start_column=first_col,
                       end_row=1, end_column=first_col + n_cols - 1)
    caption_cell = ws.cell(row=1, column=first_col, value=caption)
    caption_cell.fill = orange_fill
    return caption_cell


def process_sheet(wb, prefix, path_A, path_B):
    """Append to *wb* one sheet comparing the Excel files *path_A* and *path_B*.

    Both files are read with ``header=None`` and cut to their common row
    count. Rows 0-1 are treated as headers. Every later row is compared on
    the source columns in ``pct_cols``: each ratio is A / B * 100, except
    first-token latency, which is B / A * 100. Undefined ratios are left empty.

    Sheet layout, left to right:
      * table A
      * a blank column
      * table B
      * a blank column, which holds the '平均值' label
      * the five ratio columns

    Row 1 of each source table is merged and shows its file name. Ratio titles
    are in row 2 and ratios start in row 3. Averages sit one empty row below
    the last ratio. In each ratio column the maximum is filled green and the
    minimum red.

    Args:
        wb: openpyxl ``Workbook`` that receives the new sheet.
        prefix: pairing key, used as the sheet title after sanitizing.
        path_A: first file (ratio numerator; denominator for latency).
        path_B: second file.
    """
    df_A = pd.read_excel(path_A, header=None)
    df_B = pd.read_excel(path_B, header=None)

    # Only rows present in both files can be compared.
    min_rows = min(len(df_A), len(df_B))
    df_A, df_B = df_A.iloc[:min_rows], df_B.iloc[:min_rows]

    # 0-based source columns: total, generation, first-token latency,
    # per-stream generation, generation excluding the first token.
    pct_cols = [5, 6, 7, 10, 11]
    titles = [
        '总吞吐量(%)',
        '生成吞吐量(%)',
        '首token延迟(%)',
        '单路生成吞吐(%)',
        '不带首字生成吞吐(%)'
    ]
    # First-token latency is compared the other way round (B / A),
    # presumably because a lower latency is the better result.
    inverted_col = 7

    pct_rows = []
    for r in range(2, min_rows):
        row = []
        for c in pct_cols:
            a, b = df_A.iloc[r, c], df_B.iloc[r, c]
            row.append(_ratio(b, a) if c == inverted_col else _ratio(a, b))
        pct_rows.append(row)

    # dtype=float turns None into NaN, which mean() and dropna() then skip.
    pct_df = pd.DataFrame(pct_rows, columns=titles, dtype=float)
    avg = pct_df.mean().round(2).tolist()

    ws = wb.create_sheet(title=_sheet_title(prefix))
    k_cols, l_cols = df_A.shape[1], df_B.shape[1]

    caption_A = _write_source_table(ws, df_A.values.tolist(), 1, k_cols,
                                    os.path.basename(path_A))
    l_start = k_cols + 2               # one blank column after table A
    caption_B = _write_source_table(ws, df_B.values.tolist(), l_start, l_cols,
                                    os.path.basename(path_B))
    pct_start = l_start + l_cols + 1   # one blank column after table B

    title_cells = []
    for c_off, title in enumerate(titles):
        cell = ws.cell(row=2, column=pct_start + c_off, value=title)
        cell.fill = blue_fill
        title_cells.append(cell)

    for r_off, vals in enumerate(pct_rows):
        for c_off, val in enumerate(vals):
            ws.cell(row=3 + r_off, column=pct_start + c_off, value=val)

    # Averages go one empty row below the data. The label sits in the blank
    # column just left of the ratio columns.
    avg_row = 3 + len(pct_rows) + 1
    avg_label = ws.cell(row=avg_row, column=pct_start - 1, value='平均值')
    for c_off, val in enumerate(avg):
        ws.cell(row=avg_row, column=pct_start + c_off, value=_blank_if_nan(val))

    # Max -> green, then min -> red (min wins if both are the same cell).
    # A column without any defined ratio is left unhighlighted.
    for c_off, title in enumerate(titles):
        values = pct_df[title].dropna()
        if values.empty:
            continue
        col = pct_start + c_off
        ws.cell(row=3 + int(values.idxmax()), column=col).fill = green_fill
        ws.cell(row=3 + int(values.idxmin()), column=col).fill = red_fill

    # Uniform alignment/border/font for every used cell ...
    for row in ws.iter_rows():
        for cell in row:
            cell.alignment = center
            cell.border = border
            cell.font = DEFAULT_FONT
    # ... then emphasis, applied afterwards so the uniform pass cannot wipe it.
    # header_font is white, which is barely visible on the light blue_fill /
    # orange_fill backgrounds, so headers use a bold DEFAULT_FONT instead.
    header_bold = Font(name=DEFAULT_FONT.name, size=DEFAULT_FONT.size, bold=True)
    for cell in (caption_A, caption_B, *title_cells):
        cell.font = header_bold
    avg_label.font = black_font

    # Fit each column to its widest value; CJK characters count double.
    for col in ws.columns:
        width = max(_display_width(cell.value) for cell in col) + 2
        ws.column_dimensions[get_column_letter(col[0].column)].width = width

def main():
    """Command-line entry point: pair files across folders and save the summary.

    Exits via argparse with a usage error if any folder argument is not an
    existing directory. The workbook is saved to the path in the OUTPUT
    environment variable, or to 'summary_ratio.xlsx' when it is unset/empty.
    """
    parser = argparse.ArgumentParser(description='批量生成比例对比表（含首 token 延迟）')
    parser.add_argument('folders', nargs='+', help='两个或多个文件夹路径')
    args = parser.parse_args()

    # Report bad paths as a usage error instead of a traceback from os.listdir.
    missing = [folder for folder in args.folders if not os.path.isdir(folder)]
    if missing:
        parser.error('以下路径不是文件夹: ' + ', '.join(missing))

    pairs = build_pairs(args.folders)
    if not pairs:
        print('未找到可配对文件')
        return

    wb = Workbook()
    wb.remove(wb.active)  # drop the default empty sheet; one sheet is added per pair
    for prefix, paths in pairs.items():
        process_sheet(wb, prefix, *paths)

    output_name = os.getenv('OUTPUT') or 'summary_ratio.xlsx'
    wb.save(output_name)
    print('全部完成✅')


if __name__ == '__main__':
    main()