test_fp8_parameter.py
import torch
from nanotron.fp8 import DTypes, FP8Parameter, FP8Tensor
from nanotron.fp8.meta import FP8Meta


def test_create_fp8_parameter():
    # TODO(xrsrke): test FP8E5M2 format
    # TODO(xrsrke): test taking a CPU tensor
    tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32)

    fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3)

    assert isinstance(fp8_parameter.data, FP8Tensor)
    assert fp8_parameter.requires_grad is True
    assert fp8_parameter.grad is None
    assert isinstance(fp8_parameter.fp8_meta, FP8Meta)
    assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta)


# TODO(xrsrke): add a test that torch autograd does not run the backward pass
# on an FP8Parameter
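

# Hypothetical sketch for the FP8E5M2 TODO above, not part of the original test
# suite: it mirrors test_create_fp8_parameter and assumes FP8Parameter accepts
# DTypes.FP8E5M2 the same way it accepts DTypes.FP8E4M3.
def test_create_fp8e5m2_parameter_sketch():
    tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32)

    fp8_parameter = FP8Parameter(tensor, DTypes.FP8E5M2)

    # The same invariants are expected to hold regardless of the FP8 format.
    assert isinstance(fp8_parameter.data, FP8Tensor)
    assert fp8_parameter.requires_grad is True
    assert fp8_parameter.grad is None
    assert isinstance(fp8_parameter.fp8_meta, FP8Meta)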