metric_sorter.py 3.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Metric sort helpers for analyzer outputs.

This module keeps benchmark-specific metric ordering isolated from the generic
summary generation flow. Benchmarks without a registered sorter fall back to
plain string ordering.
"""

import re

# RCCL bandwidth metrics: <bench>/<op>_<size>_<suffix>, where an optional trailing
# ':<digits>' (presumably a rank index) is matched but kept out of the suffix group.
_RCCL_PATTERN = re.compile(
    r'^(?P<bench>rccl-bw(?::[^/]+)?)/'
    r'(?P<op>[^_]+)_(?P<size>\d+)_(?P<suffix>.+?)'
    r'(?::\d+)?$'
)
# Any gpu-hpcg metric: <bench>/<metric>, again excluding an optional ':<digits>' tail.
_HPCG_PATTERN = re.compile(
    r'^(?P<bench>gpu-hpcg(?::[^/]+)?)/'
    r'(?P<metric>.+?)'
    r'(?::\d+)?$'
)
# Per-kernel performance metrics:
#   <subject>_p<npx>x<npy>x<npz>_n<nx>x<ny>x<nz>_<type>
# The regex alternation backtracks, so 'flops_per_process' still matches even though
# 'flops' is tried first (the trailing '$' rejects the shorter alternative).
_HPCG_WORKLOAD_PATTERN = re.compile(
    r'^(?P<subject>final|ddot|waxpby|spmv|mg|total)_'
    r'p(?P<npx>\d+)x(?P<npy>\d+)x(?P<npz>\d+)_'
    r'n(?P<nx>\d+)x(?P<ny>\d+)x(?P<nz>\d+)_'
    r'(?P<type>flops|bandwidth|flops_per_process|bandwidth_per_process)$'
)
# Phase timing metrics: same p/n workload fields as above, but no trailing type.
_HPCG_TIME_PATTERN = re.compile(
    r'^(?P<subject>setup_time|optimization_time|total_time)_'
    r'p(?P<npx>\d+)x(?P<npy>\d+)x(?P<npz>\d+)_'
    r'n(?P<nx>\d+)x(?P<ny>\d+)x(?P<nz>\d+)$'
)

# Rank of each HPCG metric subject, taken from its position in this tuple: the
# timing subjects come first, then the per-kernel subjects, then 'total' and 'final'.
_HPCG_SUBJECT_ORDER = {
    subject: rank
    for rank, subject in enumerate((
        'setup_time',
        'optimization_time',
        'total_time',
        'ddot',
        'waxpby',
        'spmv',
        'mg',
        'total',
        'final',
    ))
}

# Rank of each HPCG performance metric type; used after the subject rank when
# ordering per-kernel metrics.
_HPCG_PERF_TYPE_ORDER = {
    'flops': 0,
    'bandwidth': 1,
    'flops_per_process': 2,
    'bandwidth_per_process': 3,
}


def _rccl_sort_key(metric_name):
    """Build a sort key for an ``rccl-bw`` metric name.

    Keys order by benchmark, then operation, then message size compared as an
    integer (so 8 sorts before 1024), then the remaining suffix, with the full
    metric name as the final tie-breaker.

    Args:
        metric_name: the metric name to classify.

    Returns:
        The key tuple, or None when ``metric_name`` is not an RCCL metric.
    """
    parsed = _RCCL_PATTERN.match(metric_name)
    if parsed is None:
        return None

    bench, op, size, suffix = parsed.group('bench', 'op', 'size', 'suffix')
    # The leading 0 places RCCL keys ahead of the HPCG (1) and fallback (999) keys.
    return (0, bench, op, int(size), suffix, metric_name)


def _hpcg_workload_key(match):
    """Extract the numeric process-domain and local-problem-size fields of an HPCG name.

    Args:
        match: a regex match exposing the named groups npx, npy, npz, nx, ny and nz.

    Returns:
        A tuple of six ints in the order (npx, npy, npz, nx, ny, nz), so the grids
        compare numerically rather than as strings.
    """
    dims = match.group('npx', 'npy', 'npz', 'nx', 'ny', 'nz')
    return tuple(int(dim) for dim in dims)


def _hpcg_sort_key(metric_name):
    """Sort HPCG metrics roughly in the order they appear in rocHPCG logs.

    Args:
        metric_name: full metric name of the form ``gpu-hpcg[:<tag>]/<metric>``,
            optionally followed by ``:<digits>``.

    Returns:
        A sort key tuple, or None when ``metric_name`` is not a gpu-hpcg metric.

    Every key starts with ``(1, bench, subject_rank, shape, ...)``. ``shape`` is 0
    for names parsed into subject/type/workload fields (all-int tail) and 1 for any
    other name (string tail). The discriminator sits before the mixed-type fields,
    so two keys that tie on subject rank never compare an int against a str.
    """
    match = _HPCG_PATTERN.match(metric_name)
    if not match:
        return None

    bench = match.group('bench')
    metric = match.group('metric')

    time_match = _HPCG_TIME_PATTERN.match(metric)
    if time_match:
        return (
            1,
            bench,
            _HPCG_SUBJECT_ORDER.get(time_match.group('subject'), 999),
            0,  # structured shape
            0,  # timing metrics have no perf type
            *_hpcg_workload_key(time_match),
            metric_name,
        )

    workload_match = _HPCG_WORKLOAD_PATTERN.match(metric)
    if workload_match:
        return (
            1,
            bench,
            _HPCG_SUBJECT_ORDER.get(workload_match.group('subject'), 999),
            0,  # structured shape
            _HPCG_PERF_TYPE_ORDER.get(workload_match.group('type'), 999),
            *_hpcg_workload_key(workload_match),
            metric_name,
        )

    # Unstructured name. A bare subject such as 'total_time' shares its subject rank
    # with the structured keys above; without the shape discriminator the tuple
    # comparison would reach `metric` (str) vs. an int field and raise TypeError.
    return (
        1,
        bench,
        _HPCG_SUBJECT_ORDER.get(metric, 999),
        1,  # unstructured shape: sorts after structured keys of the same subject
        metric,
        metric_name,
    )


# Benchmark-specific key builders, tried in this order; the first non-None key wins.
_SORTERS = (_rccl_sort_key, _hpcg_sort_key)


def sort_metrics(metrics):
    """Return ``metrics`` sorted with benchmark-aware keys.

    Each name is offered to the sorters in ``_SORTERS`` order, and the first
    non-None key is used. Names that no sorter recognizes get ``(999, name)``.
    That key places them after every benchmark-specific key and orders them
    as plain strings. ``sorted`` is stable, so equal keys keep their input order.

    Args:
        metrics: iterable of metric name strings.

    Returns:
        A new sorted list; the input is not modified.
    """
    def metric_key(name):
        keys = (sorter(name) for sorter in _SORTERS)
        return next((key for key in keys if key is not None), (999, name))

    return sorted(metrics, key=metric_key)