test_flops_profiler.py 3.73 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
aiss's avatar
aiss committed
5
6
7
8
9
10
11

import torch
import pytest
import deepspeed
from deepspeed.profiling.flops_profiler import get_model_profile
from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest
aiss's avatar
aiss committed
12
from unit.util import required_minimum_torch_version
aiss's avatar
aiss committed
13

aiss's avatar
aiss committed
14
# Module-level mark: pytest applies it to every test in this file, skipping them
# all unless the installed PyTorch meets the minimum version (1.3, per `reason`).
pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=3),
                                reason='requires Pytorch version 1.3 or above')


def within_range(val, target, tolerance):
    """Return True if ``val`` lies within a relative ``tolerance`` of ``target``.

    The relative error is ``|val - target| / |target|``. It is compared strictly
    against ``tolerance``. Using ``|target|`` keeps the check meaningful for
    negative targets; the previous form divided by a negative number, which
    made every value pass.

    Args:
        val: measured value.
        target: expected value.
        tolerance: allowed relative error (e.g. 0.05 for 5%).

    Returns:
        bool: whether ``val`` is close enough to ``target``. When ``target`` is
        zero, a relative error is undefined, so only an exact match passes
        (instead of raising ZeroDivisionError).
    """
    if target == 0:
        return val == 0
    return abs(val - target) / abs(target) < tolerance


# Relative tolerance (5%) handed to within_range() by the FLOP/MAC assertions below.
TOLERANCE = 5e-2


class LeNet5(torch.nn.Module):
    """LeNet-5 style CNN used as a fixed fixture for FLOP/MAC/parameter counting.

    The layer configuration is deliberately hard-coded. The exact counts asserted
    in ``test_flops_profiler_in_inference`` depend on it, so changing any layer
    breaks that test.

    Args:
        n_classes: width of the final linear layer (number of output classes).
    """

    def __init__(self, n_classes):
        super().__init__()

        # Single-channel input. For a 32x32 image the spatial size goes
        # 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5 -conv5-> 1, so the
        # extractor ends with 120 channels of 1x1. That matches the 120 input
        # features of the classifier's first linear layer.
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            torch.nn.Tanh(),
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=120, out_features=84),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: input tensor; assumed (batch, 1, 32, 32) so the flattened
                features are 120 wide. TODO confirm for other input sizes.

        Returns:
            tuple: ``(logits, probs)``, where ``probs`` is the softmax of
            ``logits`` over dimension 1.
        """
        x = self.feature_extractor(x)
        # Keep the batch dimension, collapse channels and spatial dimensions.
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probs = torch.nn.functional.softmax(logits, dim=1)
        return logits, probs


class TestFlopsProfiler(DistributedTest):
    """Checks of DeepSpeed's FLOPS profiler, run on a single rank."""
    world_size = 1

    def test(self):
        """Profile fp16 training of ``SimpleModel`` through the DeepSpeed engine.

        The engine is built with ``flops_profiler`` enabled and trained for at
        most four steps. The profiler's recorded FLOP count must then be within
        ``TOLERANCE`` of 200, and its parameter count must be exactly 110.
        """
        config_dict = {
            "train_batch_size": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.001,
                }
            },
            "zero_optimization": {
                "stage": 0
            },
            "fp16": {
                "enabled": True,
            },
            "flops_profiler": {
                "enabled": True,
                "step": 1,
                "module_depth": -1,
                "top_modules": 3,
            },
        }
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=False)

        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.half)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            # Stop after the step at n == 3 (four steps in total).
            # NOTE(review): presumably this is enough for the profiler
            # configured with "step": 1 to have recorded its numbers; confirm.
            if n == 3:
                break
        assert within_range(model.flops_profiler.flops, 200, tolerance=TOLERANCE)
        assert model.flops_profiler.params == 110

    def test_flops_profiler_in_inference(self):
        """Profile a ``LeNet5(10)`` forward pass with ``get_model_profile``.

        Uses a batch of 1024 single-channel 32x32 inputs. The returned FLOP and
        MAC counts must be within ``TOLERANCE`` of the reference values, and the
        parameter count must be exactly 61706.
        """
        model = LeNet5(10)
        batch_size = 1024
        inputs = torch.randn(batch_size, 1, 32, 32)
        flops, macs, params = get_model_profile(
            model,
            tuple(inputs.shape),
            print_profile=True,
            detailed=True,
            module_depth=-1,
            top_modules=3,
            warm_up=1,
            as_string=False,
            ignore_modules=None,
        )
        print(flops, macs, params)
        assert within_range(flops, 866076672, TOLERANCE)
        assert within_range(macs, 426516480, TOLERANCE)
        assert params == 61706