test_fused_dense.py 7.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, dtype):
    """Compare FusedDenseTD with torch.nn.Linear on forward output and gradients.

    The fused module is given the reference module's weight and bias, both are
    run on identical inputs, and the outputs plus the input / weight / bias
    gradients must agree within dtype-dependent tolerances.
    """
    device = 'cuda'
    batch_size, seqlen = 8, 512
    # bfloat16 gets a looser absolute tolerance than float16.
    rtol = 3e-3
    atol = 1e-2 if dtype == torch.bfloat16 else 1e-3

    # Fixed seed; the order of the random draws below determines the test data.
    torch.random.manual_seed(0)
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    # Independent leaf tensor with the same values, so each model gets its own .grad.
    x = x_pt.detach().clone().requires_grad_()

    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
    # Give the fused module exactly the reference parameters.
    with torch.no_grad():
        for name in ('weight', 'bias'):
            getattr(model, name).copy_(getattr(model_pt, name))

    out_pt = model_pt(x_pt)
    out = model(x)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # Scale the upstream gradient down (by 32); unscaled, the gradients get a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)

    # (fused grad, reference grad, absolute tolerance); the error for d_weight
    # and d_bias is quite a bit higher than for d_input.
    grad_checks = (
        (x.grad, x_pt.grad, atol),
        (model.weight.grad, model_pt.weight.grad, atol * 10),
        (model.bias.grad, model_pt.bias.grad, atol * 5),
    )
    for actual, expected, tol in grad_checks:
        assert torch.allclose(actual, expected, rtol=rtol, atol=tol)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
def test_fused_linear_bias_residual(in_features, out_features, dtype):
    """Compare FusedDenseResidual with torch.nn.Linear plus a residual branch.

    The fused module is called as returning a pair; the second element is
    passed through GELU and added to the first, mirroring the reference
    ``linear(x) + gelu(x)``. Outputs and gradients must agree within
    dtype-dependent tolerances.
    """
    device = 'cuda'
    batch_size, seqlen = 8, 512
    # bfloat16 gets a looser absolute tolerance than float16.
    rtol = 3e-3
    atol = 1e-2 if dtype == torch.bfloat16 else 1e-3

    # Fixed seed; the order of the random draws below determines the test data.
    torch.random.manual_seed(0)
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    # Independent leaf tensor with the same values, so each model gets its own .grad.
    x = x_pt.detach().clone().requires_grad_()

    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
    # Give the fused module exactly the reference parameters.
    with torch.no_grad():
        for name in ('weight', 'bias'):
            getattr(model, name).copy_(getattr(model_pt, name))

    # Reference: add an arbitrary function (GELU) of the residual input.
    linear_pt = model_pt(x_pt)
    out_pt = linear_pt + F.gelu(x_pt)
    # Fused: apply the same function to the second value returned by the module.
    linear_out, residual = model(x)
    out = linear_out + F.gelu(residual)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)

    # Scale the upstream gradient down (by 32); unscaled, the gradients get a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)

    # (fused grad, reference grad, absolute tolerance); the error for d_weight
    # and d_bias is quite a bit higher than for d_input.
    grad_checks = (
        (x.grad, x_pt.grad, atol),
        (model.weight.grad, model_pt.weight.grad, atol * 10),
        (model.bias.grad, model_pt.bias.grad, atol * 5),
    )
    for actual, expected, tol in grad_checks:
        assert torch.allclose(actual, expected, rtol=rtol, atol=tol)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [1, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
    """Compare FusedDenseGeluDenseTD with an unfused Linear -> GELU(tanh) -> Linear stack.

    The fused module's fc1 / fc2 receive the reference layers' parameters.
    Forward outputs and the gradients of the input and of every weight and
    bias must agree within dtype-dependent tolerances.
    """
    device = 'cuda'
    batch_size, seqlen = 8, 512
    # bfloat16 gets a looser absolute tolerance than float16.
    rtol = 3e-3
    atol = 1e-2 if dtype == torch.bfloat16 else 1e-3

    # Fixed seed; the order of the random draws below determines the test data.
    torch.random.manual_seed(0)
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    # Independent leaf tensor with the same values, so each model gets its own .grad.
    x = x_pt.detach().clone().requires_grad_()

    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseGeluDenseTD(
        in_features, out_features, in_features,
        checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
        device=device, dtype=dtype,
    )
    # Pair each fused layer with its reference counterpart.
    layer_pairs = ((model.fc1, model_pt_fc1), (model.fc2, model_pt_fc2))
    with torch.no_grad():
        for fused_layer, ref_layer in layer_pairs:
            fused_layer.weight.copy_(ref_layer.weight)
            fused_layer.bias.copy_(ref_layer.bias)

    # Unfused reference computation.
    hidden_pt = F.gelu(model_pt_fc1(x_pt), approximate='tanh')
    out_pt = model_pt_fc2(hidden_pt)
    out = model(x)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # Scale the upstream gradient down (by 32); unscaled, the gradients get a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)

    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher than for d_input.
    for fused_layer, ref_layer in layer_pairs:
        assert torch.allclose(fused_layer.weight.grad, ref_layer.weight.grad,
                              rtol=rtol, atol=atol * 10)
        assert torch.allclose(fused_layer.bias.grad, ref_layer.bias.grad,
                              rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
    """Compare FusedDenseResGeluDense with an unfused MLP plus a residual branch.

    The fused module is called as returning a pair; the second element is
    passed through GELU and added to the first, mirroring the reference
    ``fc2(gelu_tanh(fc1(x))) + gelu(x)``. Outputs and all gradients must agree
    within dtype-dependent tolerances.
    """
    device = 'cuda'
    batch_size, seqlen = 8, 512
    # bfloat16 gets a looser absolute tolerance than float16.
    rtol = 3e-3
    atol = 1e-2 if dtype == torch.bfloat16 else 1e-3

    # Fixed seed; the order of the random draws below determines the test data.
    torch.random.manual_seed(0)
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    # Independent leaf tensor with the same values, so each model gets its own .grad.
    x = x_pt.detach().clone().requires_grad_()

    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseResGeluDense(
        in_features, out_features, in_features,
        checkpoint_lvl=checkpoint_lvl,
        device=device, dtype=dtype,
    )
    # Pair each fused layer with its reference counterpart.
    layer_pairs = ((model.fc1, model_pt_fc1), (model.fc2, model_pt_fc2))
    with torch.no_grad():
        for fused_layer, ref_layer in layer_pairs:
            fused_layer.weight.copy_(ref_layer.weight)
            fused_layer.bias.copy_(ref_layer.bias)

    # Unfused reference: MLP output plus an arbitrary function (GELU) of the input.
    hidden_pt = F.gelu(model_pt_fc1(x_pt), approximate='tanh')
    out_pt = model_pt_fc2(hidden_pt) + F.gelu(x_pt)
    # Fused: apply the same function to the second value returned by the module.
    mlp_out, residual = model(x)
    out = mlp_out + F.gelu(residual)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)

    # Scale the upstream gradient down (by 32); unscaled, the gradients get a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)

    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher than for d_input.
    for fused_layer, ref_layer in layer_pairs:
        assert torch.allclose(fused_layer.weight.grad, ref_layer.weight.grad,
                              rtol=rtol, atol=atol * 10)
        assert torch.allclose(fused_layer.bias.grad, ref_layer.bias.grad,
                              rtol=rtol, atol=atol * 5)