#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Parse performance records from qz.txt and inject tuning rules into tuning.cc.

qz.txt is read as ``test_np=<N>`` header lines, each followed by performance
lines of the form::

    Operation: AllReduce, Size=1024 bytes, baseline=1.0, min_other=0.9 (RING,LL), diff=10.0%

For every ``test_np`` section the script generates C code guarded by
``info->comm->nRanks == N``. For each recorded (operation, algorithm,
protocol) inside a widened byte range, that code sets ``*time = 0`` and
returns ``ncclSuccess``. The code is inserted directly after the
``float lat = info->comm->latencies[info->coll][algorithm][protocol];`` line
of tuning.cc.

Usage: python parse_qz_and_modify_tuning.py <qz.txt> <tuning.cc>
"""
import re
import sys
from collections import defaultdict

# One performance record line. Groups: operation, size, baseline, min_other,
# algorithm, protocol, diff (percent).
_RECORD_RE = re.compile(
    r'Operation: (\w+), Size=(\d+) bytes, baseline=([\d.]+), min_other=([\d.]+) '
    r'\((RING|TREE),(LL|SIMPLE)\), diff=([\d.]+)%'
)

# qz.txt short names -> NCCL identifiers emitted into the generated C code.
_ALGORITHM_MAP = {
    'RING': 'NCCL_ALGO_RING',
    'TREE': 'NCCL_ALGO_TREE',
}
_PROTOCOL_MAP = {
    'LL': 'NCCL_PROTO_LL',
    'SIMPLE': 'NCCL_PROTO_SIMPLE',
}
_OPERATION_MAP = {
    'AllReduce': 'ncclFuncAllReduce',
    'AllGather': 'ncclFuncAllGather',
    'Broadcast': 'ncclFuncBroadcast',
    'ReduceScatter': 'ncclFuncReduceScatter',
    'Reduce': 'ncclFuncReduce',
}

# Records whose diff (percent) is >= this threshold get the wide size window.
_DIFF_THRESHOLD = 8.0
_WIDE_MULTIPLIERS = (0.5, 2.0)
_NARROW_MULTIPLIERS = (0.75, 1.5)

# Anchor statement in tuning.cc; generated code is inserted right after it.
_ANCHOR_RE = re.compile(
    r'(float lat = info->comm->latencies\[info->coll\]\[algorithm\]\[protocol\];)'
)


def _range_multipliers(diff):
    """Return ``(min_multiplier, max_multiplier)`` for a diff percentage."""
    return _WIDE_MULTIPLIERS if diff >= _DIFF_THRESHOLD else _NARROW_MULTIPLIERS


def parse_qz_file(qz_file_path):
    """Parse qz.txt into performance records grouped by ``test_np``.

    Lines are stripped and then classified:

    * ``test_np=<int>`` starts (or re-opens) a section for that value.
    * A line matching ``_RECORD_RE`` is appended to the current section.
      Records seen before any ``test_np`` line are ignored.
    * Records whose operation has no NCCL mapping are skipped with a warning
      on stderr.
    * Every other line is ignored.

    Args:
        qz_file_path: Path to the UTF-8 encoded qz.txt file.

    Returns:
        dict mapping each ``test_np`` (int) to a list of record dicts with keys
        ``operation``, ``size``, ``baseline``, ``min_other``, ``algorithm``,
        ``protocol`` and ``diff``. Operation, algorithm and protocol are
        already translated to their NCCL identifiers.

    Raises:
        ValueError: if a ``test_np=`` line does not carry an integer.
    """
    data_by_test_np = {}  # {test_np: [record, ...]}
    current_test_np = None

    with open(qz_file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()

            if line.startswith('test_np='):
                try:
                    current_test_np = int(line.split('=')[1])
                except ValueError as err:
                    raise ValueError(
                        f"{qz_file_path}:{line_no}: invalid test_np line: {line!r}"
                    ) from err
                data_by_test_np.setdefault(current_test_np, [])
                continue

            match = _RECORD_RE.match(line)
            if not match or current_test_np is None:
                continue

            operation = match.group(1)
            # The regex accepts any word as the operation; skip the ones we
            # cannot translate instead of crashing with a bare KeyError.
            if operation not in _OPERATION_MAP:
                print(f"警告: 第 {line_no} 行包含不支持的操作 {operation!r},已跳过",
                      file=sys.stderr)
                continue

            data_by_test_np[current_test_np].append({
                'operation': _OPERATION_MAP[operation],
                'size': int(match.group(2)),
                'baseline': float(match.group(3)),
                'min_other': float(match.group(4)),
                'algorithm': _ALGORITHM_MAP[match.group(5)],
                'protocol': _PROTOCOL_MAP[match.group(6)],
                'diff': float(match.group(7)),
            })

    return data_by_test_np


def find_mergeable_sequences(raw_data):
    """Find runs of records whose sizes successively double.

    Records are grouped by (operation, algorithm, protocol) and sorted by
    size within each group. A run continues while each size is exactly twice
    the previous one; equal or other sizes end it.

    Args:
        raw_data: list of record dicts as produced by ``parse_qz_file``.

    Returns:
        list of runs. Each run is a size-ordered list of at least two of the
        original record dicts (the same objects, not copies).
    """
    grouped = defaultdict(list)
    for d in raw_data:
        grouped[(d['operation'], d['algorithm'], d['protocol'])].append(d)

    merge_sequences = []
    for records in grouped.values():
        records.sort(key=lambda x: x['size'])
        current_sequence = [records[0]]
        for record in records[1:]:
            if record['size'] == current_sequence[-1]['size'] * 2:
                current_sequence.append(record)
            else:
                if len(current_sequence) >= 2:
                    merge_sequences.append(current_sequence)
                current_sequence = [record]
        if len(current_sequence) >= 2:
            merge_sequences.append(current_sequence)

    return merge_sequences


def _append_rule(code_lines, record, adjusted_min, adjusted_max):
    """Append a C ``if`` block that zeroes ``*time`` for *record* in (min, max]."""
    code_lines.append(
        f" if (info->coll == {record['operation']} && algorithm == {record['algorithm']} && "
        f"protocol == {record['protocol']} && info->nBytes > {adjusted_min} && "
        f"info->nBytes <= {adjusted_max}) {{"
    )
    code_lines.append(" *time = 0;")
    code_lines.append(" return ncclSuccess;")
    code_lines.append(" }")


def generate_all_adjustments(data_by_test_np):
    """Generate the C snippet for every ``test_np`` section.

    Each non-empty section becomes an ``if (info->comm->nRanks == N) { ... }``
    block that contains:

    * one rule per mergeable doubling run. The lower bound uses the first
      record's diff and the upper bound uses the last record's diff.
    * one rule per remaining standalone record.

    The window is ``[size*0.5, size*2.0]`` when diff >= 8%, otherwise
    ``[size*0.75, size*1.5]``. Bounds are truncated to int.

    Args:
        data_by_test_np: mapping returned by ``parse_qz_file``.

    Returns:
        The generated C code as a single newline-joined string.
    """
    all_code_lines = [
        " // 根据 qz.txt 性能数据自动调整执行时间",
        " // 支持多个 nRanks 配置的优化",
        " // 调整策略:",
        " // - diff≥8%: size_range = [original×0.5, original×2.0]",
        " // - diff<8%: size_range = [original×0.75, original×1.5]",
        ""
    ]

    for test_np in sorted(data_by_test_np.keys()):
        raw_data = data_by_test_np[test_np]
        if not raw_data:
            continue

        merge_sequences = find_mergeable_sequences(raw_data)

        all_code_lines.append(f" // 优化配置: nRanks == {test_np}")
        all_code_lines.append(f" if (info->comm->nRanks == {test_np}) {{")

        # Merged doubling runs.
        for seq in merge_sequences:
            first, last = seq[0], seq[-1]
            adjusted_min = int(first['size'] * _range_multipliers(first['diff'])[0])
            adjusted_max = int(last['size'] * _range_multipliers(last['diff'])[1])
            all_code_lines.append(
                f" // {first['operation']} {first['algorithm']} {first['protocol']}: "
                f"{adjusted_min}~{adjusted_max} bytes (原始 {first['size']}~{last['size']})"
            )
            _append_rule(all_code_lines, first, adjusted_min, adjusted_max)

        # Standalone records. Track merged records by object identity: this is
        # O(1) per lookup, unlike list.index(), which is O(n) and compares dicts
        # by value.
        merged_ids = {id(d) for seq in merge_sequences for d in seq}
        for d in raw_data:
            if id(d) in merged_ids:
                continue
            min_multiplier, max_multiplier = _range_multipliers(d['diff'])
            _append_rule(all_code_lines, d,
                         int(d['size'] * min_multiplier),
                         int(d['size'] * max_multiplier))

        all_code_lines.append(" }")  # close the current nRanks condition
        all_code_lines.append("")

    return "\n".join(all_code_lines)


def modify_tuning_cc(tuning_cc_path, adjustment_code):
    """Insert *adjustment_code* after every anchor line in tuning.cc.

    The file is rewritten only if at least one anchor was found. Running the
    script twice inserts the code twice; the anchor is not consumed.

    Args:
        tuning_cc_path: Path to the UTF-8 encoded tuning.cc file.
        adjustment_code: C code returned by ``generate_all_adjustments``.

    Returns:
        Number of insertion points found (0 means the file was left untouched).
    """
    with open(tuning_cc_path, 'r', encoding='utf-8') as f:
        content = f.read()

    def replace_func(match):
        # Use a function rather than a replacement string so that backslashes
        # in the generated code are not treated as regex escapes.
        return match.group(1) + "\n" + adjustment_code

    new_content, count = _ANCHOR_RE.subn(replace_func, content)
    if count:
        with open(tuning_cc_path, 'w', encoding='utf-8') as f:
            f.write(new_content)
    return count


def main():
    """CLI entry point: ``parse_qz_and_modify_tuning.py <qz.txt> <tuning.cc>``."""
    if len(sys.argv) != 3:
        print("Usage: python parse_qz_and_modify_tuning.py <qz.txt> <tuning.cc>",
              file=sys.stderr)
        sys.exit(1)

    qz_file_path = sys.argv[1]
    tuning_cc_path = sys.argv[2]

    data_by_test_np = parse_qz_file(qz_file_path)
    # Sections without any records would only inject the header comment.
    if not any(data_by_test_np.values()):
        print("错误: qz.txt 中未找到有效的 test_np 和性能数据", file=sys.stderr)
        sys.exit(1)

    print(f"找到 {len(data_by_test_np)} 个 test_np 配置:")
    for test_np, data in sorted(data_by_test_np.items()):
        print(f" test_np={test_np}: {len(data)} 条优化记录")

    adjustment_code = generate_all_adjustments(data_by_test_np)
    replaced = modify_tuning_cc(tuning_cc_path, adjustment_code)
    if replaced == 0:
        print("错误: tuning.cc 中未找到插入位置 "
              "(float lat = info->comm->latencies[info->coll][algorithm][protocol];),文件未修改",
              file=sys.stderr)
        sys.exit(1)
    if replaced > 1:
        print(f"警告: 插入位置出现 {replaced} 次,调整代码已插入到每一处", file=sys.stderr)

    print(f"\ntuning.cc 文件已成功修改,包含 {len(data_by_test_np)} 个 nRanks 配置的优化")


if __name__ == "__main__":
    main()