# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Mapping

import oneflow as flow

from libai.utils import distributed as dist


def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch):
    """
    Pad each tensor in ``x_dict`` along the batch dimension to ``batch_size`` so the
    last (possibly incomplete) batch can still be split evenly across data-parallel
    ranks. Returns the padded tensor dict and the number of valid (non-padding) samples.
    """
    x = list(x_dict.values())[0]
    tensor_batch = x.shape[0]
    assert tensor_batch <= batch_size

    if tensor_batch == batch_size and not is_last_batch:
        return x_dict, batch_size

    valid_sample = tensor_batch - last_batch_lack
    data_parallel_size = dist.get_data_parallel_size()
    assert tensor_batch % data_parallel_size == 0
    tensor_micro_batch_size = tensor_batch // data_parallel_size
    padded_dict = {}
    for key, xi in x_dict.items():
        pad_shape = (batch_size, *xi.shape[1:])
        # Materialize the full tensor locally on every device before padding.
        local_xi = xi.to_global(
            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda")
        ).to_local()
        # Copy the existing samples into a zero-initialized buffer of the target batch size.
        padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device="cuda")
        padded_xi[:tensor_batch, ...] = padded_xi[:tensor_batch, ...] + local_xi
        # Rearrange rows to account for the ``last_batch_lack`` missing samples.
        for i in range(last_batch_lack - 1):
            start_idx = tensor_micro_batch_size * (data_parallel_size - i - 1) - 1
            padded_xi[start_idx:-1] = padded_xi[start_idx + 1 :]
        # Restore the original SBP signature and placement of the input tensor.
        padded_xi = padded_xi.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=xi.placement
        ).to_global(sbp=xi.sbp)
        padded_dict[key] = padded_xi
    return padded_dict, valid_sample
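
# A minimal, hypothetical usage sketch (names, shapes, and counts are illustrative and
# not part of the original file). Suppose the last evaluation batch already has
# ``batch_size`` rows but the final ``last_batch_lack`` of them are padding:
#
#     x_dict = {"input_ids": input_ids}  # global tensor of shape (8, seq_len)
#     padded, valid_sample = pad_batch(
#         x_dict, batch_size=8, last_batch_lack=2, is_last_batch=True
#     )
#     # valid_sample == 8 - 2 == 6; only the first 6 rows are real samples.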


def print_csv_format(results):
    """
    Print main metrics in a particular format
    so that they are easy to copypaste into a spreadsheet.

    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}.
            An unordered dict can also be printed, but in arbitrary order.
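
    Example (illustrative values, not from the original file): a result dict like
    ``{"bbox": {"AP": 0.3805, "AP50": 0.5921, "AP-person": 0.5123}}`` would produce
    log lines of the form::

        copypaste: Task: bbox
        copypaste: AP,AP50
        copypaste: 0.3805,0.5921

    Metrics whose names contain "-" (e.g. per-category scores) are skipped.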
    """
    assert isinstance(results, Mapping) or not len(results), results
    logger = logging.getLogger(__name__)
    for task, res in results.items():
        if isinstance(res, Mapping):
            # Don't print "AP-category" metrics since they are usually not tracked.
            important_res = [(k, v) for k, v in res.items() if "-" not in k]
            logger.info("copypaste: Task: {}".format(task))
            logger.info("copypaste: " + ",".join([k[0] for k in important_res]))
            logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res]))
        else:
            logger.info(f"copypaste: {task}={res}")


def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict): a possibly nested dict of scalar values.
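
    Example (illustrative values, not from the original file):
        >>> flatten_results_dict({"bbox": {"AP": 0.5, "AP50": 0.7}, "loss": 1.2})
        {'bbox/AP': 0.5, 'bbox/AP50': 0.7, 'loss': 1.2}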
    """
    r = {}
    for k, v in results.items():
        if isinstance(v, Mapping):
            v = flatten_results_dict(v)
            for kk, vv in v.items():
                r[k + "/" + kk] = vv
        else:
            r[k] = v
    return r