# test_linear.py
import pytest
import torch
from nanotron.fp8 import DTypes, FP8Linear, FP8Parameter, FP8Tensor
from torch import nn
from torch.optim import Adam


@pytest.mark.parametrize("is_bias", [True, False])
def test_fp8_linear_forward_pass(is_bias):
    input = torch.randn(16, 16, device="cuda", dtype=torch.float32)
    ref_input = input.detach().clone()
    ref_linear = nn.Linear(16, 16, bias=is_bias, device="cuda", dtype=torch.float32)

    fp8_linear = FP8Linear(16, 16, bias=is_bias, device="cuda:0")
    fp8_linear.weight = FP8Parameter(ref_linear.weight.detach().clone(), DTypes.FP8E4M3)

    if is_bias:
        fp8_linear.bias.data = ref_linear.bias.detach().clone()

    ref_output = ref_linear(ref_input)
    output = fp8_linear(input)

    assert isinstance(output, torch.Tensor)
    assert output.dtype == torch.float32
    assert torch.allclose(output, ref_output, rtol=0, atol=0.1)


# TODO(xrsrke): add cases where the input does and doesn't require grad
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_fp8_linear_backward_pass(input_requires_grad, device):
    input = torch.randn(16, 16, device=device, dtype=torch.float32, requires_grad=input_requires_grad)
    ref_input = input.detach().clone().requires_grad_(True)
    ref_linear = nn.Linear(16, 16, device=device, dtype=torch.float32)
    fp8_linear = FP8Linear(16, 16, device=device)

    if device == "cpu":
        fp8_linear.weight.data = ref_linear.weight.detach().clone()
    else:
        fp8_linear.weight.data = FP8Tensor(ref_linear.weight.detach().clone(), dtype=DTypes.FP8E4M3)
    fp8_linear.bias.data = ref_linear.bias.detach().clone()

    ref_linear(ref_input).sum().backward()
    fp8_linear(input).sum().backward()

    # TODO(xrsrke): investigate why input.grad needs such a high tolerance
    # assert torch.allclose(input.grad, ref_input.grad, 0.2, 0.2) if input_requires_grad else True
    assert torch.allclose(fp8_linear.weight.grad, ref_linear.weight.grad, 0.1, 0.1)
    assert torch.allclose(fp8_linear.bias.grad, ref_linear.bias.grad, 0, 0.1)


# TODO(xrsrke): test if FP8Linear has all the methods of a torch.nn.Linear

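# A minimal sketch of the test suggested in the TODO above: it only checks that the
# public callables of torch.nn.Linear also exist on FP8Linear. Treating method-name
# parity as the meaning of "has all the methods" is an assumption, not something this
# file specifies.
def test_fp8_linear_has_nn_linear_methods():
    ref_linear = nn.Linear(16, 16, device="cuda")
    fp8_linear = FP8Linear(16, 16, device="cuda:0")

    ref_methods = {
        name for name in dir(ref_linear) if not name.startswith("_") and callable(getattr(ref_linear, name))
    }
    missing = {name for name in ref_methods if not hasattr(fp8_linear, name)}

    assert missing == set()
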

def test_fp8_linear_attrs():
    fp8_linear = FP8Linear(16, 16, device="cuda:0")

    assert next(fp8_linear.parameters()) is not None
    assert all(p.requires_grad for p in fp8_linear.parameters())


# TODO(xrsrke): test only calculating the gradients of the weight, bias, or input based
# on the requires_grad of the input, weight, or bias

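# A minimal sketch of the test suggested in the TODO above, assuming FP8Linear follows
# standard autograd semantics: a parameter frozen with requires_grad_(False) should end
# up with grad=None after backward. Whether FP8Linear's custom backward honours this is
# an assumption this sketch probes, not a documented guarantee.
def test_fp8_linear_skips_grad_for_frozen_params():
    fp8_linear = FP8Linear(16, 16, device="cuda:0")
    fp8_linear.bias.requires_grad_(False)

    input = torch.randn(16, 16, device="cuda", requires_grad=True)
    fp8_linear(input).sum().backward()

    assert fp8_linear.weight.grad is not None
    assert fp8_linear.bias.grad is None
    assert input.grad is not None
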

def test_fp8_model_bwd():
    HIDDEN_SIZE = 128
    N_LAYERS = 5
    N_EPOCHS = 3

    input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", requires_grad=True)

    model = nn.Sequential(
        *[nn.Sequential(FP8Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda"), nn.ReLU()) for _ in range(N_LAYERS)]
    )
    optim = Adam(model.parameters(), lr=1e-3)

    for _ in range(N_EPOCHS):
        optim.zero_grad()
        model(input).sum().backward()
        optim.step()

    assert all(p.grad is not None for p in model.parameters())