Commit 54860539 authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

type hints in Params4bit constructors

parent 74c00eb1
......@@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings
import torch
......@@ -142,7 +142,7 @@ class Embedding(torch.nn.Embedding):
class Params4bit(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
if data is None:
data = torch.empty(0)
......@@ -155,7 +155,7 @@ class Params4bit(torch.nn.Parameter):
return self
@classmethod
def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment