# performance_evaluator.py
from time import perf_counter, time
from typing import Optional

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.cluster import DistCoordinator


def divide(x: float, y: float) -> float:
    """Divide ``x`` by ``y`` without raising on degenerate denominators.

    Returns:
        ``inf`` when ``y`` compares equal to zero (regardless of the sign of ``x``),
        ``nan`` when ``y`` is positive infinity, and ``x / y`` otherwise.
    """
    # Ordinary denominators take the fast path straight to true division.
    if y != 0 and y != float("inf"):
        return x / y
    return float("inf") if y == 0 else float("nan")


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    """Sum a Python scalar across ranks with ``dist.all_reduce`` and divide by ``world_size``.

    When ``world_size`` is 1 the value is returned untouched: no tensor is
    allocated and no collective is issued.

    NOTE(review): the result is a true mean only if ``world_size`` equals the
    size of the default process group used by ``dist.all_reduce`` — confirm at call sites.
    """
    if world_size != 1:
        # The collective operates on tensors, so wrap the scalar on the current CUDA device.
        buf = torch.tensor([x], device=torch.cuda.current_device())
        dist.all_reduce(buf)
        return (buf / world_size).item()
    return x


class Timer:
    """Accumulating stopwatch.

    Every ``start()`` / ``end()`` pair adds the elapsed interval (in seconds)
    to ``duration``; ``reset()`` clears the accumulated total.
    """

    def __init__(self) -> None:
        # Timestamp at which the currently running interval began; None when stopped.
        self.start_time: Optional[float] = None
        # Total seconds accumulated across all completed intervals.
        self.duration: float = 0.0

    def start(self) -> None:
        """Begin a new timing interval."""
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # system clock and can jump (e.g. NTP adjustments), corrupting durations.
        self.start_time = perf_counter()

    def end(self) -> None:
        """Stop the running interval and add its length to ``duration``.

        Raises:
            RuntimeError: if called without a preceding ``start()``.
        """
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if self.start_time is None:
            raise RuntimeError("Timer.end() called before Timer.start()")
        self.duration += perf_counter() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        """Zero the accumulated duration (a currently running interval is not affected)."""
        self.duration = 0.0


class PerformanceEvaluator:
    """
    Callback for evaluating the training performance (throughput and TFLOPS) of a model.

    Args:
        model_numel: The number of parameters of the model.
        enable_grad_checkpoint: Whether gradient checkpointing is enabled. If so, one extra
            forward pass per step is included in the FLOP estimate.
        ignore_steps: Steps whose index is below this value are excluded from the measurement.
        dp_world_size: The data-parallel world size. Defaults to the global world size.
    """

    def __init__(
        self,
        model_numel: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
    ) -> None:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps

        self.coordinator = DistCoordinator()
        self.dp_world_size = dp_world_size or self.coordinator.world_size
        # True while the current step is one of the ignored leading steps.
        self.disable: bool = False
        self.timer = Timer()
        # Samples / FLOPs counted on this rank over all measured steps.
        self.num_samples: int = 0
        self.flop: int = 0

    def on_step_start(self, step: int) -> None:
        """Start timing ``step`` unless it falls within the ignored leading steps."""
        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
        if self.disable:
            return
        # Drain pending CUDA work so the timer does not include the previous step.
        torch.cuda.synchronize()
        self.timer.start()

    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
        """Stop timing the step and accumulate its sample and FLOP counts.

        ``input_ids`` is unpacked as ``(batch_size, seq_len)``; extra keyword
        arguments are accepted and ignored.
        """
        if self.disable:
            return
        # Wait for the step's kernels to finish before reading the clock.
        torch.cuda.synchronize()
        self.timer.end()

        batch_size, seq_len = input_ids.shape

        self.num_samples += batch_size
        # 2 FLOPs per parameter per token for one forward pass; forward + backward is
        # counted as 3 forward-equivalents, plus one recomputed forward when gradient
        # checkpointing is enabled.
        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))

    def on_fit_end(self) -> None:
        """Print average throughput (samples/sec) and TFLOPS per GPU on the master rank."""
        # Average the measured time over every rank in the job.
        avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
        # NOTE(review): assumes every data-parallel rank processed the same number of samples.
        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
        # Ranks that jointly hold one model replica (e.g. tensor/pipeline parallel).
        mp_world_size = self.coordinator.world_size // self.dp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        self.coordinator.print_on_master(
            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
            f"avg_throughput: {avg_throughput}"
        )
        self.coordinator.print_on_master(
            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
        )