parameter.py 1.59 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch import nn

from nanotron.fp8.constants import FP8_DTYPES
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import FP8Tensor


class FP8Parameter(nn.Parameter):
    """
    A custom FP8 parameter class that allows gradients
    to flow into FP8 tensors (which are integer tensors).

    The parameter object itself is built from the original non-FP8 tensor,
    while the quantized ``FP8Tensor`` is stored in ``_data`` and exposed
    through the overridden ``data`` property.
    """

    def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter:
        """
        Create the parameter from ``data`` and attach its FP8 quantization.

        Args:
            data: A non-FP8 tensor that is not located on the CPU.
            dtype: The FP8 format forwarded to ``FP8Tensor``.
            requires_grad: Whether the parameter requires gradients.

        Raises:
            AssertionError: If ``data`` is not a tensor, already has an FP8
                dtype, or lives on the CPU.
        """
        assert isinstance(data, torch.Tensor), "data must be a tensor"
        assert data.dtype not in FP8_DTYPES, "Currently only support turn a non-fp8 tensor to an fp8 parameter"
        assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors"
        # TODO(xrsrke): if the tensor is on cpu, then bypass quantization

        with torch.no_grad():
            # TODO(xrsrke): support taking an FP8Tensor as data.
            # Currently we can only quantize the tensor to FP8 after the
            # parameter has been created from the non-FP8 tensor, because
            # creating it from an FP8 tensor raises "Only Tensors of floating
            # point and complex dtype can require gradients".
            self = torch.Tensor._make_subclass(cls, data, requires_grad)
            self._data = FP8Tensor(data, dtype=dtype)
        return self

    @property
    def data(self) -> FP8Tensor:
        """Return the quantized tensor stored in ``_data``."""
        return self._data

    @data.setter
    def data(self, data: FP8Tensor):
        """Replace the stored quantized tensor (no validation is performed)."""
        self._data = data

    @property
    def fp8_meta(self) -> FP8Meta:
        """Return the ``fp8_meta`` attribute of the stored quantized tensor."""
        return self.data.fp8_meta

    def __repr__(self) -> str:
        """Return a debug representation with balanced parentheses."""
        # The original format string omitted the closing parenthesis.
        return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad})"