test_vit.py 1.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import re

import torch
import pytest

from timm.models.vision_transformer import vit_base_patch16_224

from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224


@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
@pytest.mark.parametrize('optimized', [False, True])
def test_vit(optimized, fused_dense_gelu_dense):
    """Check that our implementation of ViT matches timm's implementation.

    The output of our forward pass in fp16 should be about as close to timm's
    fp32 forward pass as timm's own fp16 forward pass is: its max abs error
    must stay below 3x the max abs error of timm's fp16 model.

    Args:
        optimized: if True, build our model with ``use_flash_attn``,
            ``fused_bias_fc`` and ``fused_dropout_add_ln`` enabled.
        fused_dense_gelu_dense: forwarded unchanged to the flash_attn ViT
            constructor.
    """
    dtype = torch.float16
    device = 'cuda'

    kwargs = {'fused_dense_gelu_dense': fused_dense_gelu_dense}
    if optimized:
        kwargs.update(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

    # The fp32 timm model is the numerical reference. The fp16 timm model
    # measures how much error plain fp16 inference introduces on its own,
    # which sets the tolerance for our model.
    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

    # Run our model with the same pretrained weights as the timm reference.
    model.load_state_dict(model_ref.state_dict())

    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    # Inference only: disabling autograd avoids retaining activations for a
    # backward pass that never happens (largest saving on the fp32 pass).
    with torch.no_grad():
        out = model(x)
        out_timm = model_timm(x)
        out_ref = model_ref(x.float())

    diff = (out - out_ref).abs()
    diff_timm = (out_timm - out_ref).abs()
    max_diff = diff.max().item()
    max_diff_timm = diff_timm.max().item()

    print(f'Output max diff: {max_diff}')
    print(f'Output mean diff: {diff.mean().item()}')
    print(f'timm fp16 max diff: {max_diff_timm}')
    print(f'timm fp16 mean diff: {diff_timm.mean().item()}')
    assert max_diff < 3 * max_diff_timm, (
        f'flash ViT fp16 max diff {max_diff} is not below '
        f'3x timm fp16 max diff {max_diff_timm}'
    )