Unverified commit a8c9dfa6, authored by Aarni Koskela, committed by GitHub

Fix some issues found by Mypy (#995)

* Fix erroneous type aliasing

* Fix `Optional` typings (see PEP 484; a before/after sketch follows the commit metadata below)

* Add Mypy ignores

* Fix Mypy complaints for method tables

* Fix type for get_ptr

* Fix various Mypy errors

* Fix missed call to is_triton_available
parent 32be2897
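For context on the `Optional` bullet above: PEP 484 does not treat a `None` default as implying `Optional`, and recent Mypy versions reject such implicit-Optional signatures by default, which is why defaults like `out: Tensor = None` throughout the library are rewritten below. A minimal before/after sketch of the pattern; the function names here are illustrative, not taken from the codebase:

```python
from typing import Optional

import torch


# Before: the None default silently implied Optional. Recent Mypy versions
# flag this ("Incompatible default for argument ... [assignment]").
def scale_implicit(A: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    ...


# After: the Optional intent is spelled out explicitly, matching the pattern
# applied throughout this commit.
def scale_explicit(A: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(A)
    torch.mul(A, 2, out=out)
    return out
```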
@@ -2,7 +2,7 @@ import operator
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
from typing import Tuple, Optional, List
from typing import Tuple, Optional, Callable
from warnings import warn
import torch
@@ -14,9 +14,6 @@ import bitsandbytes.functional as F
def prod(iterable):
return reduce(operator.mul, iterable, 1)
tensor = torch.Tensor
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -56,7 +53,10 @@ class GlobalOutlierPooler:
return torch.Tensor(list(self.outliers)).to(torch.int64)
def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
def get_inverse_transform_indices(
transform_tile: Callable[[torch.Tensor], torch.Tensor],
tile_size: Tuple[int, int],
):
"""
Compute a permutation of indices that invert the specified (tiled) matrix transformation
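The hunk above is an instance of the "erroneous type aliasing" bullet: the builtin `callable` is a predicate function, not a type, so annotating a parameter with it gives Mypy nothing to check, whereas `typing.Callable[[torch.Tensor], torch.Tensor]` describes the expected argument and return types. A small illustrative sketch under that assumption; `apply_tile_transform` is a made-up name, not from the diff:

```python
from typing import Callable, Tuple

import torch


def apply_tile_transform(
    transform_tile: Callable[[torch.Tensor], torch.Tensor],
    tile_size: Tuple[int, int],
) -> torch.Tensor:
    # Mypy now knows transform_tile takes one Tensor and returns a Tensor,
    # so a call such as transform_tile("oops") would be flagged.
    tile = torch.arange(tile_size[0] * tile_size[1]).reshape(tile_size)
    return transform_tile(tile)


# Usage: any Tensor -> Tensor function satisfies the annotation.
permuted = apply_tile_transform(lambda t: t.t().contiguous(), (2, 3))
```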
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
@@ -549,10 +549,10 @@ class MatMul4Bit(torch.autograd.Function):
def matmul(
A: tensor,
B: tensor,
out: tensor = None,
state: MatmulLtState = None,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None
):
@@ -562,7 +562,7 @@ def matmul(
return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
......
@@ -34,9 +34,9 @@ from .env_vars import get_potentially_lib_path_containing_env_vars
# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
system = platform.system()
if system == 'Windows':
CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
else: # Linux or other
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
CUDA_RUNTIME_LIBS = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
backup_paths = []
......
@@ -12,7 +12,7 @@ import math
import numpy as np
from functools import reduce # Required in Python 3
from typing import Tuple, Any, Dict
from typing import Tuple, Any, Dict, Optional
from torch import Tensor
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
@@ -27,71 +27,83 @@ name2qmap = {}
if COMPILED_WITH_CUDA:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
str2optimizer32bit["momentum"] = (
str2optimizer32bit = {
"adam": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"momentum": (
lib.cmomentum32bit_grad_32,
lib.cmomentum32bit_grad_16,
)
str2optimizer32bit["rmsprop"] = (
),
"rmsprop": (
lib.crmsprop32bit_grad_32,
lib.crmsprop32bit_grad_16,
)
str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
str2optimizer32bit["adagrad"] = (
),
"lion": (
lib.clion32bit_grad_fp32,
lib.clion32bit_grad_fp16,
lib.clion32bit_grad_bf16,
),
"adagrad": (
lib.cadagrad32bit_grad_32,
lib.cadagrad32bit_grad_16,
)
),
}
str2optimizer8bit = {}
str2optimizer8bit["adam"] = (
str2optimizer8bit = {
"adam": (
lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_16,
)
str2optimizer8bit["momentum"] = (
),
"momentum": (
lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_16,
)
str2optimizer8bit["rmsprop"] = (
),
"rmsprop": (
lib.crmsprop_static_8bit_grad_32,
lib.crmsprop_static_8bit_grad_16,
)
str2optimizer8bit["lion"] = (
),
"lion": (
lib.clion_static_8bit_grad_32,
lib.clion_static_8bit_grad_16,
)
str2optimizer8bit["lamb"] = (
),
"lamb": (
lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_16,
)
str2optimizer8bit["lars"] = (
),
"lars": (
lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_16,
)
),
}
str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise["adam"] = (
str2optimizer8bit_blockwise = {
"adam": (
lib.cadam_8bit_blockwise_grad_fp32,
lib.cadam_8bit_blockwise_grad_fp16,
lib.cadam_8bit_blockwise_grad_bf16,
)
str2optimizer8bit_blockwise["momentum"] = (
),
"momentum": (
lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp16,
)
str2optimizer8bit_blockwise["rmsprop"] = (
),
"rmsprop": (
lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp16,
)
str2optimizer8bit_blockwise["lion"] = (
),
"lion": (
lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp16,
lib.clion_8bit_blockwise_grad_bf16,
)
str2optimizer8bit_blockwise["adagrad"] = (
),
"adagrad": (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
)
),
}
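The restructuring above, building each `str2optimizer*` table as a single dict literal instead of starting from `{}` and assigning keys one by one, is what the "method tables" bullet refers to: a bare `{}` gives Mypy nothing to infer a value type from, and it tends to pin the type from the first assignment (or ask for an explicit annotation), while a literal lets it infer one type covering all entries. A standalone illustration under those assumptions; the kernel names are made up:

```python
def kernel_fp32() -> None: ...
def kernel_fp16() -> None: ...
def kernel_bf16() -> None: ...


# Incremental build (the old pattern): the value type is left for Mypy to
# guess from the first assignment, so mixing 3-tuples and 2-tuples of kernels
# across separate assignments tends to draw complaints unless the dict is
# explicitly annotated.
table_incremental = {}
table_incremental["adam"] = (kernel_fp32, kernel_fp16, kernel_bf16)
table_incremental["momentum"] = (kernel_fp32, kernel_fp16)

# Single literal (the new pattern): Mypy infers one value type covering all
# entries up front, with no annotation needed.
table_literal = {
    "adam": (kernel_fp32, kernel_fp16, kernel_bf16),
    "momentum": (kernel_fp32, kernel_fp16),
}
```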
class GlobalPageManager:
_instance = None
@@ -400,7 +412,8 @@ def is_on_gpu(tensors):
raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
return on_gpu
def get_ptr(A: Tensor) -> ct.c_void_p:
def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
"""
Get the ctypes pointer from a PyTorch Tensor.
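The widened `Optional` signature for `get_ptr` reflects that callers pass `None` for optional buffers and expect a null pointer back. A minimal sketch of a body consistent with that signature; this is an assumption about the implementation, not code copied from the diff:

```python
import ctypes as ct
from typing import Optional

import torch


def get_ptr(A: Optional[torch.Tensor]) -> Optional[ct.c_void_p]:
    # A missing tensor maps to a missing pointer; otherwise hand the tensor's
    # data pointer to the C library as a void*.
    if A is None:
        return None
    return ct.c_void_p(A.data.data_ptr())
```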
@@ -521,7 +534,7 @@ def nvidia_transform(
return out, new_state
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -626,8 +639,8 @@ class QuantState:
# unpacking minor and non-tensor quant state items if necessary
if len(qs_key) == 1:
qs_key = qs_key[0]
qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
first_qs_key = qs_key[0]
qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))
qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes
assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
@@ -694,7 +707,14 @@ class QuantState:
self.state2.code = self.state2.code.to(device)
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
def quantize_blockwise(
A: Tensor,
code: Optional[torch.Tensor] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=4096,
nested=False,
) -> Tuple[Tensor, QuantState]:
"""
Quantize tensor A in blocks of size 4096 values.
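The new `Tuple[Tensor, QuantState]` return annotation makes explicit that blockwise quantization hands back both the quantized tensor and the state needed to undo it. A hedged usage sketch of the round trip, assuming a CUDA-enabled bitsandbytes build and a GPU; this is not part of the diff:

```python
import torch

import bitsandbytes.functional as F

# Quantize a tensor in blocks; the second return value is the QuantState
# that dequantize_blockwise needs to reverse the mapping.
A = torch.randn(4096, device="cuda")
A_q, quant_state = F.quantize_blockwise(A, blocksize=4096)
A_dq = F.dequantize_blockwise(A_q, quant_state)

# Blockwise 8-bit quantization is lossy but close.
print((A - A_dq).abs().max())
```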
@@ -769,10 +789,10 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
def dequantize_blockwise(
A: Tensor,
quant_state: QuantState = None,
absmax: Tensor = None,
code: Tensor = None,
out: Tensor = None,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 4096,
nested=False
) -> Tensor:
@@ -891,17 +911,17 @@ def get_4bit_type(typename, device=None, blocksize=64):
return data.to(device)
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
def quantize_4bit(
A: Tensor,
absmax: Tensor = None,
out: Tensor = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_type='fp4',
@@ -987,13 +1007,13 @@ def quantize_4bit(
return out, state
def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
"""
Dequantizes FP4 blockwise quantized values.
@@ -1070,7 +1090,11 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor =
else: return out
def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
def quantize(
A: Tensor,
code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -1086,10 +1110,10 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
def dequantize(
A: Tensor,
state: Tuple[Tensor, Tensor] = None,
absmax: Tensor = None,
code: Tensor = None,
out: Tensor = None,
state: Optional[Tuple[Tensor, Tensor]] = None,
absmax: Optional[torch.Tensor] = None,
code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
) -> Tensor:
assert state is not None or absmax is not None
if code is None and state is None:
@@ -1104,7 +1128,7 @@ def dequantize(
return out * state[0]
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
'''
Quantizes input tensor to 8-bit.
@@ -1133,7 +1157,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
return out
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
'''
Dequantizes the 8-bit tensor to 32-bit.
@@ -1171,11 +1195,11 @@ def optimizer_update_32bit(
eps: float,
step: int,
lr: float,
state2: Tensor = None,
state2: Optional[torch.Tensor] = None,
beta2: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Tensor = None,
unorm_vec: Optional[torch.Tensor] = None,
max_unorm: float = 0.0,
skip_zeros=False,
) -> None:
@@ -1274,7 +1298,7 @@ def optimizer_update_8bit(
new_max2: Tensor,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Tensor = None,
unorm_vec: Optional[torch.Tensor] = None,
max_unorm: float = 0.0,
) -> None:
"""
@@ -1603,7 +1627,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
def gemv_4bit(
A: Tensor,
B: Tensor,
out: Tensor = None,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
state=None
@@ -1663,7 +1687,7 @@ def gemv_4bit(
def igemm(
A: Tensor,
B: Tensor,
out: Tensor = None,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
):
@@ -1752,7 +1776,7 @@ def igemm(
def batched_igemm(
A: Tensor,
B: Tensor,
out: Tensor = None,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
):
......
@@ -145,7 +145,7 @@ class Params4bit(torch.nn.Parameter):
cls,
data: Optional[torch.Tensor] = None,
requires_grad=True,
quant_state: QuantState = None,
quant_state: Optional[QuantState] = None,
blocksize: int = 64,
compress_statistics: bool = True,
quant_type: str = 'fp4',
......
@@ -162,7 +162,7 @@ class SwitchBackLinear(nn.Linear):
):
super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available:
if not is_triton_available():
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
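The hunk above fixes the "missed call to is_triton_available" bullet, a classic truthiness bug: `is_triton_available` is a function, a function object is always truthy, so `if not is_triton_available:` could never raise. A small self-contained sketch of why; the helper here is a stand-in for the real check:

```python
def is_triton_available() -> bool:
    # Stand-in for the real check; pretend triton is missing.
    return False


# Buggy: `not <function object>` is always False, so the guard silently passes.
if not is_triton_available:
    raise ImportError("unreachable: a function object is always truthy")

# Fixed: call the function and branch on its return value.
if not is_triton_available():
    print("triton not available, falling back")
```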
......
@@ -2,6 +2,7 @@ import operator
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
from typing import Optional
import torch
@@ -14,7 +15,6 @@ from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler
def prod(iterable):
return reduce(operator.mul, iterable, 1)
tensor = torch.Tensor
class MatMulFP8Mixed(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
@@ -389,19 +389,38 @@ def get_block_sizes(input_matrix, weight_matrix):
return bsz, bsz2
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
def matmul_fp8_global(
A: torch.Tensor,
B: torch.Tensor,
fw_code: torch.Tensor,
bw_code: torch.Tensor,
out: Optional[torch.Tensor] = None,
bsz: int = -1,
bsz2: int = -1,
):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
def matmul_fp8_mixed(
A: torch.Tensor,
B: torch.Tensor,
fw_code: torch.Tensor,
bw_code: torch.Tensor,
out: Optional[torch.Tensor] = None,
bsz: int = -1,
bsz2: int = -1,
):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
def switchback_bnb(
A: tensor,
B: tensor,
out: tensor = None,
state: MatmulLtState = None,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None
):
......
@@ -35,3 +35,11 @@ combine-as-imports = true
detect-same-package = true
force-sort-within-sections = true
known-first-party = ["bitsandbytes"]
[[tool.mypy.overrides]]
module = "triton.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.stats"
ignore_missing_imports = true