Unverified commit a8c9dfa6, authored by Aarni Koskela, committed by GitHub

Fix some issues found by Mypy (#995)

* Fix erroneous type aliasing

* Fix `Optional` typings (see PEP 484)

* Add Mypy ignores

* Fix Mypy complaints for method tables

* Fix type for get_ptr

* Fix various Mypy errors

* Fix missed call to is_triton_available
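
Two of these fixes follow one pattern worth spelling out. Under PEP 484, a parameter annotated `x: Tensor = None` is no longer implicitly `Optional[Tensor]`, so Mypy rejects the `None` default; and the module-level alias `tensor = torch.Tensor` was an erroneous type alias (a lowercase runtime alias masquerading as a type). A minimal sketch of the corrected pattern; `scale` is a hypothetical function, not part of bitsandbytes:

from typing import Optional

import torch


def scale(A: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    # An explicit Optional[...] is required for a None default under PEP 484.
    if out is None:
        out = torch.empty_like(A)
    torch.mul(A, 2.0, out=out)
    return out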
parent 32be2897
@@ -2,7 +2,7 @@ import operator
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional, Callable
 from warnings import warn
 
 import torch
@@ -14,9 +14,6 @@ import bitsandbytes.functional as F
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
-tensor = torch.Tensor
-
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -56,7 +53,10 @@ class GlobalOutlierPooler:
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+def get_inverse_transform_indices(
+    transform_tile: Callable[[torch.Tensor], torch.Tensor],
+    tile_size: Tuple[int, int],
+):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation
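
The builtin `callable` used in the old signature is a runtime predicate, not a type; typing code wants `typing.Callable[[arg_types], return_type]`, which also documents the expected signature. A hypothetical illustration of the annotation:

from typing import Callable, Tuple

import torch


def flip_tile(t: torch.Tensor) -> torch.Tensor:
    # Example tile transform: reverse the rows of the tile.
    return t.flip(0)


transform: Callable[[torch.Tensor], torch.Tensor] = flip_tile
tile_size: Tuple[int, int] = (8, 32)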
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -549,10 +549,10 @@ class MatMul4Bit(torch.autograd.Function):
 
 def matmul(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
@@ -562,7 +562,7 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
-def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
+def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
......
@@ -34,9 +34,9 @@ from .env_vars import get_potentially_lib_path_containing_env_vars
 # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
 system = platform.system()
 if system == 'Windows':
-    CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
+    CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
 else: # Linux or other
-    CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
+    CUDA_RUNTIME_LIBS = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
 
 # this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
 backup_paths = []
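
Dropping the `: list` annotations here is not just cosmetic: Mypy complains when the same name is annotated more than once, even across `if`/`else` branches, so leaving both bindings unannotated lets the type be inferred once. A minimal illustration with a hypothetical name:

import platform

if platform.system() == 'Windows':
    # Annotating this assignment AND the one below would make Mypy
    # report that the name is already defined.
    RUNTIME_LIBS = ["nvcuda.dll"]
else:
    RUNTIME_LIBS = ["libcudart.so"]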
......
@@ -12,7 +12,7 @@ import math
 import numpy as np
 
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Any, Dict
+from typing import Tuple, Any, Dict, Optional
 from torch import Tensor
 
 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
@@ -27,71 +27,83 @@ name2qmap = {}
 if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
-    str2optimizer32bit = {}
-    str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
-    str2optimizer32bit["momentum"] = (
-        lib.cmomentum32bit_grad_32,
-        lib.cmomentum32bit_grad_16,
-    )
-    str2optimizer32bit["rmsprop"] = (
-        lib.crmsprop32bit_grad_32,
-        lib.crmsprop32bit_grad_16,
-    )
-    str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
-    str2optimizer32bit["adagrad"] = (
-        lib.cadagrad32bit_grad_32,
-        lib.cadagrad32bit_grad_16,
-    )
-
-    str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (
-        lib.cadam_static_8bit_grad_32,
-        lib.cadam_static_8bit_grad_16,
-    )
-    str2optimizer8bit["momentum"] = (
-        lib.cmomentum_static_8bit_grad_32,
-        lib.cmomentum_static_8bit_grad_16,
-    )
-    str2optimizer8bit["rmsprop"] = (
-        lib.crmsprop_static_8bit_grad_32,
-        lib.crmsprop_static_8bit_grad_16,
-    )
-    str2optimizer8bit["lion"] = (
-        lib.clion_static_8bit_grad_32,
-        lib.clion_static_8bit_grad_16,
-    )
-    str2optimizer8bit["lamb"] = (
-        lib.cadam_static_8bit_grad_32,
-        lib.cadam_static_8bit_grad_16,
-    )
-    str2optimizer8bit["lars"] = (
-        lib.cmomentum_static_8bit_grad_32,
-        lib.cmomentum_static_8bit_grad_16,
-    )
-
-    str2optimizer8bit_blockwise = {}
-    str2optimizer8bit_blockwise["adam"] = (
-        lib.cadam_8bit_blockwise_grad_fp32,
-        lib.cadam_8bit_blockwise_grad_fp16,
-        lib.cadam_8bit_blockwise_grad_bf16,
-    )
-    str2optimizer8bit_blockwise["momentum"] = (
-        lib.cmomentum_8bit_blockwise_grad_fp32,
-        lib.cmomentum_8bit_blockwise_grad_fp16,
-    )
-    str2optimizer8bit_blockwise["rmsprop"] = (
-        lib.crmsprop_8bit_blockwise_grad_fp32,
-        lib.crmsprop_8bit_blockwise_grad_fp16,
-    )
-    str2optimizer8bit_blockwise["lion"] = (
-        lib.clion_8bit_blockwise_grad_fp32,
-        lib.clion_8bit_blockwise_grad_fp16,
-        lib.clion_8bit_blockwise_grad_bf16,
-    )
-    str2optimizer8bit_blockwise["adagrad"] = (
-        lib.cadagrad_8bit_blockwise_grad_fp32,
-        lib.cadagrad_8bit_blockwise_grad_fp16,
-    )
+    str2optimizer32bit = {
+        "adam": (
+            lib.cadam32bit_grad_fp32,
+            lib.cadam32bit_grad_fp16,
+            lib.cadam32bit_grad_bf16,
+        ),
+        "momentum": (
+            lib.cmomentum32bit_grad_32,
+            lib.cmomentum32bit_grad_16,
+        ),
+        "rmsprop": (
+            lib.crmsprop32bit_grad_32,
+            lib.crmsprop32bit_grad_16,
+        ),
+        "lion": (
+            lib.clion32bit_grad_fp32,
+            lib.clion32bit_grad_fp16,
+            lib.clion32bit_grad_bf16,
+        ),
+        "adagrad": (
+            lib.cadagrad32bit_grad_32,
+            lib.cadagrad32bit_grad_16,
+        ),
+    }
+
+    str2optimizer8bit = {
+        "adam": (
+            lib.cadam_static_8bit_grad_32,
+            lib.cadam_static_8bit_grad_16,
+        ),
+        "momentum": (
+            lib.cmomentum_static_8bit_grad_32,
+            lib.cmomentum_static_8bit_grad_16,
+        ),
+        "rmsprop": (
+            lib.crmsprop_static_8bit_grad_32,
+            lib.crmsprop_static_8bit_grad_16,
+        ),
+        "lion": (
+            lib.clion_static_8bit_grad_32,
+            lib.clion_static_8bit_grad_16,
+        ),
+        "lamb": (
+            lib.cadam_static_8bit_grad_32,
+            lib.cadam_static_8bit_grad_16,
+        ),
+        "lars": (
+            lib.cmomentum_static_8bit_grad_32,
+            lib.cmomentum_static_8bit_grad_16,
+        ),
+    }
+
+    str2optimizer8bit_blockwise = {
+        "adam": (
+            lib.cadam_8bit_blockwise_grad_fp32,
+            lib.cadam_8bit_blockwise_grad_fp16,
+            lib.cadam_8bit_blockwise_grad_bf16,
+        ),
+        "momentum": (
+            lib.cmomentum_8bit_blockwise_grad_fp32,
+            lib.cmomentum_8bit_blockwise_grad_fp16,
+        ),
+        "rmsprop": (
+            lib.crmsprop_8bit_blockwise_grad_fp32,
+            lib.crmsprop_8bit_blockwise_grad_fp16,
+        ),
+        "lion": (
+            lib.clion_8bit_blockwise_grad_fp32,
+            lib.clion_8bit_blockwise_grad_fp16,
+            lib.clion_8bit_blockwise_grad_bf16,
+        ),
+        "adagrad": (
+            lib.cadagrad_8bit_blockwise_grad_fp32,
+            lib.cadagrad_8bit_blockwise_grad_fp16,
+        ),
+    }
 
 
 class GlobalPageManager:
     _instance = None
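
Besides readability, the literal form gives Mypy one place to infer a single value type for each table, instead of joining types from dozens of incremental item assignments. A hedged sketch of how such a name-to-kernel table is typically consumed (hypothetical names, not the library's dispatch code):

from typing import Callable


# Stand-ins for the per-dtype CUDA entry points.
def _update_fp32() -> None: ...
def _update_fp16() -> None: ...


# Mypy infers dict[str, tuple[Callable[[], None], ...]] from the literal.
OPTIM_FUNCS = {
    "adam": (_update_fp32, _update_fp16),
}


def pick_kernel(name: str, use_fp16: bool) -> Callable[[], None]:
    # Index 0 holds the fp32 kernel, index 1 the fp16 kernel,
    # mirroring the (grad_fp32, grad_fp16, ...) tuple layout above.
    return OPTIM_FUNCS[name][1 if use_fp16 else 0]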
@@ -400,7 +412,8 @@ def is_on_gpu(tensors):
         raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
     return on_gpu
 
-def get_ptr(A: Tensor) -> ct.c_void_p:
+
+def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
     """
     Get the ctypes pointer from a PyTorch Tensor.
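
The widened signature acknowledges that callers routinely pass `None` for absent tensors and forward the result straight to ctypes, where `None` becomes a NULL pointer. A rough sketch of the idea (hypothetical helper, not the library's exact body):

import ctypes as ct
from typing import Optional

import torch


def get_ptr_sketch(A: Optional[torch.Tensor]) -> Optional[ct.c_void_p]:
    # None propagates through: ctypes converts None to a NULL pointer
    # when it is passed as a C function argument.
    if A is None:
        return None
    return ct.c_void_p(A.data_ptr())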
@@ -521,7 +534,7 @@ def nvidia_transform(
     return out, new_state
 
-def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -626,8 +639,8 @@ class QuantState:
         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
-            qs_key = qs_key[0]
-            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+            first_qs_key = qs_key[0]
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))
 
         qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
         assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
@@ -694,7 +707,14 @@ class QuantState:
             self.state2.code = self.state2.code.to(device)
 
-def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
+def quantize_blockwise(
+    A: Tensor,
+    code: Optional[torch.Tensor] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize=4096,
+    nested=False,
+) -> Tuple[Tensor, QuantState]:
     """
     Quantize tensor A in blocks of size 4096 values.
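
The corrected return annotation (`Tuple[Tensor, QuantState]` instead of `Tensor`) matches what the function actually returns: the quantized payload plus the state needed to invert it. A short usage sketch, assuming a CUDA-enabled bitsandbytes install:

import torch

import bitsandbytes.functional as F

A = torch.randn(4096, device="cuda")
# quantize_blockwise returns a pair, not a bare tensor.
q, quant_state = F.quantize_blockwise(A)
A_restored = F.dequantize_blockwise(q, quant_state)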
@@ -769,10 +789,10 @@ def quantize_blockwise(
 
 def dequantize_blockwise(
     A: Tensor,
-    quant_state: QuantState = None,
-    absmax: Tensor = None,
-    code: Tensor = None,
-    out: Tensor = None,
+    quant_state: Optional[QuantState] = None,
+    absmax: Optional[torch.Tensor] = None,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     blocksize: int = 4096,
     nested=False
 ) -> Tensor:
@@ -891,17 +911,17 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
 
-def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
 
 def quantize_4bit(
     A: Tensor,
-    absmax: Tensor = None,
-    out: Tensor = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     blocksize=64,
     compress_statistics=False,
     quant_type='fp4',
@@ -987,13 +1007,13 @@ def quantize_4bit(
     return out, state
 
-def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
 
-def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
 
-def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
@@ -1070,7 +1090,11 @@ def dequantize_4bit(
     else: return out
 
-def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
+def quantize(
+    A: Tensor,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -1086,10 +1110,10 @@ def quantize(
 
 def dequantize(
     A: Tensor,
-    state: Tuple[Tensor, Tensor] = None,
-    absmax: Tensor = None,
-    code: Tensor = None,
-    out: Tensor = None,
+    state: Optional[Tuple[Tensor, Tensor]] = None,
+    absmax: Optional[torch.Tensor] = None,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
 ) -> Tensor:
     assert state is not None or absmax is not None
     if code is None and state is None:
@@ -1104,7 +1128,7 @@ def dequantize(
     return out * state[0]
 
-def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
+def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     '''
     Quantizes input tensor to 8-bit.
@@ -1133,7 +1157,7 @@ def quantize_no_absmax(
     return out
 
-def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
+def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     '''
     Dequantizes the 8-bit tensor to 32-bit.
@@ -1171,11 +1195,11 @@ def optimizer_update_32bit(
     eps: float,
     step: int,
     lr: float,
-    state2: Tensor = None,
+    state2: Optional[torch.Tensor] = None,
     beta2: float = 0.0,
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
-    unorm_vec: Tensor = None,
+    unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
     skip_zeros=False,
 ) -> None:
@@ -1274,7 +1298,7 @@ def optimizer_update_8bit(
     new_max2: Tensor,
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
-    unorm_vec: Tensor = None,
+    unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
 ) -> None:
     """
@@ -1603,7 +1627,7 @@ def check_matmul(
 def gemv_4bit(
     A: Tensor,
     B: Tensor,
-    out: Tensor = None,
+    out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
     state=None
@@ -1663,7 +1687,7 @@ def gemv_4bit(
 
 def igemm(
     A: Tensor,
     B: Tensor,
-    out: Tensor = None,
+    out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
 ):
@@ -1752,7 +1776,7 @@ def igemm(
 
 def batched_igemm(
     A: Tensor,
     B: Tensor,
-    out: Tensor = None,
+    out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
 ):
......
@@ -145,7 +145,7 @@ class Params4bit(torch.nn.Parameter):
         cls,
         data: Optional[torch.Tensor] = None,
         requires_grad=True,
-        quant_state: QuantState = None,
+        quant_state: Optional[QuantState] = None,
         blocksize: int = 64,
         compress_statistics: bool = True,
         quant_type: str = 'fp4',
......
@@ -162,7 +162,7 @@ class SwitchBackLinear(nn.Linear):
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
 
-        if not is_triton_available:
+        if not is_triton_available():
            raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
                              Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
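
The missing call was a genuine bug, not a style fix: a bare function object is always truthy, so `if not is_triton_available:` could never fire and the ImportError was unreachable. A minimal illustration with a hypothetical predicate:

def feature_available() -> bool:
    return False

# Buggy: the function object is truthy, so this guard never triggers.
if not feature_available:
    raise ImportError("unreachable")

# Fixed: call the predicate and test its result.
if not feature_available():
    raise ImportError("feature missing")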
......
@@ -2,6 +2,7 @@ import operator
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
+from typing import Optional
 
 import torch
@@ -14,7 +15,6 @@ from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
-tensor = torch.Tensor
-
 
 class MatMulFP8Mixed(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
@@ -389,19 +389,38 @@ def get_block_sizes(input_matrix, weight_matrix):
     return bsz, bsz2
 
-def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
+
+def matmul_fp8_global(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    fw_code: torch.Tensor,
+    bw_code: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bsz: int = -1,
+    bsz2: int = -1,
+):
     if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
     return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
 
-def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
+
+def matmul_fp8_mixed(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    fw_code: torch.Tensor,
+    bw_code: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bsz: int = -1,
+    bsz2: int = -1,
+):
     if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
     return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
 
 def switchback_bnb(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
......
@@ -34,4 +34,12 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports
 combine-as-imports = true
 detect-same-package = true
 force-sort-within-sections = true
-known-first-party = ["bitsandbytes"]
\ No newline at end of file
+known-first-party = ["bitsandbytes"]
+
+[[tool.mypy.overrides]]
+module = "triton.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "scipy.stats"
+ignore_missing_imports = true
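
These `[[tool.mypy.overrides]]` tables silence missing-stub errors for `triton` and `scipy.stats` in one central place. Without them, every import site would need an inline suppression, roughly:

# Without the pyproject override, Mypy reports a missing library stub
# for triton, unless each import carries an inline ignore:
import triton  # type: ignore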