# test_shape_prop.py
import pytest
import torch
import torchvision.models as tm
4
5
6
7
from packaging import version

from colossalai.testing.utils import parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
8
9
10
11
12
13

# The ColossalAI analyzer APIs are imported best-effort: they presumably fail to
# import on older torch builds, where the tests below are skipped by version anyway.
try:
    from colossalai._analyzer._subclasses import MetaTensorMode
    from colossalai._analyzer.fx import symbolic_trace
    from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
    from colossalai._analyzer.fx.symbolic_profile import register_shape_impl

    @register_shape_impl(torch.nn.functional.linear)
    def linear_impl(*args, **kwargs):
        """Shape implementation registered for ``F.linear``: simply run the real op."""
        return torch.nn.functional.linear(*args, **kwargs)
except Exception:
    # Deliberately best-effort (see above), but unlike a bare ``except:`` this no
    # longer swallows KeyboardInterrupt / SystemExit.
    pass


def _check_gm_validity(gm: torch.fx.GraphModule):
    """Assert that every node of ``gm`` carries shape-propagation metadata.

    Every node must have a truthy ``node.meta['info'].outputs``. Nodes whose op
    is ``call_module``, ``call_function`` or ``call_method`` must additionally
    expose an ``inputs`` attribute on that info object.

    Args:
        gm: A traced graph module, e.g. after ``shape_prop_pass`` has run on it.

    Raises:
        AssertionError: If a node lacks output shapes, or an executing node lacks
            an ``inputs`` attribute.
        KeyError: If a node has no ``'info'`` entry in its ``meta`` dict.
    """
    # Ops that execute code; these can apply to params, so they must record inputs.
    ops_with_inputs = ('call_module', 'call_function', 'call_method')
    for node in gm.graph.nodes:
        info = node.meta['info']
        assert info.outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
        if node.op in ops_with_inputs:
            assert hasattr(info, 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.'


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12.0')
@parameterize('m', tm_models)
def test_torchvision_shape_prop(m):
    """Trace a torchvision model, run shape propagation, and validate node metadata.

    ``m`` is a zero-argument model constructor taken from ``tm_models``
    (presumably ``parameterize`` runs this test once per entry — confirm in
    colossalai.testing).
    """
    # Model and input are built under MetaTensorMode; assumed to produce
    # meta/fake tensors so the 100x3x224x224 batch is never materialised — TODO confirm.
    with MetaTensorMode():
        model = m()
        data = torch.rand(100, 3, 224, 224)
    # Keyed by forward-argument name; assumes the model's input is named ``x``.
    meta_args = {
        "x": data,
    }
    gm = symbolic_trace(model, meta_args=meta_args)
    shape_prop_pass(gm, data)
    _check_gm_validity(gm)


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12.0')
@parameterize('m', tmm_models)
def test_timm_shape_prop(m):
    """Trace a timm model, run shape propagation, and validate node metadata.

    ``m`` is a zero-argument model constructor taken from ``tmm_models``
    (presumably ``parameterize`` runs this test once per entry — confirm in
    colossalai.testing).
    """
    # Model and input are built under MetaTensorMode; assumed to produce
    # meta/fake tensors so the 100x3x224x224 batch is never materialised — TODO confirm.
    with MetaTensorMode():
        model = m()
        data = torch.rand(100, 3, 224, 224)
    # Keyed by forward-argument name; assumes the model's input is named ``x``.
    meta_args = {
        "x": data,
    }
    gm = symbolic_trace(model, meta_args=meta_args)
    shape_prop_pass(gm, data)
    _check_gm_validity(gm)


if __name__ == "__main__":
    # Called without arguments: the `parameterize` decorator presumably supplies `m`.
    test_torchvision_shape_prop()
    test_timm_shape_prop()