Unverified commit 677ff400, authored by Matthew Douglas, committed by GitHub

Drop Python 3.8 support. (#1574)

* Drop Python 3.8 support.

* Formatting
parent 9b339952
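
Beyond the packaging metadata (requires-python, classifiers, ruff target), the bulk of this diff mechanically applies the typing cleanup that a Python 3.9 floor allows: PEP 585 built-in generics replace their typing aliases, and Sequence/Iterable/Iterator imports move from typing to collections.abc. A minimal before/after sketch of the pattern (illustrative names, not taken from the diff):

# Python 3.8 style: generic aliases imported from typing
# from typing import List, Tuple
# def shape_of(dims) -> Tuple[int, int]: ...

# Python 3.9+ style: built-in types are subscriptable (PEP 585),
# and abstract collection types come from collections.abc
from collections.abc import Sequence

def shape_of(dims: Sequence[int]) -> tuple[int, int]:
    # Illustrative helper, not part of bitsandbytes.
    return (dims[0], dims[-1])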
@@ -111,7 +111,7 @@ jobs:
   matrix:
     os: [ubuntu-latest, macos-latest, windows-latest]
     # The specific Python version is irrelevant in this context as we are only packaging non-C extension
-    # code. This ensures compatibility across Python versions, including Python 3.8, as compatibility is
+    # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is
     # dictated by the packaged code itself, not the Python version used for packaging.
     python-version: ["3.10"]
     arch: [x86_64, aarch64]
...
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.9
+    rev: v0.11.2
     hooks:
       - id: ruff
         args:
...
@@ -65,4 +65,4 @@ for i in range(5):
 print("=" * 40)
 print(f"Example:\n{tokenizer.decode(generated_ids[0])}")
 print("=" * 40)
-print(f"Speed: {num/(time.time() - time_1)}token/s")
+print(f"Speed: {num / (time.time() - time_1)}token/s")
@@ -66,7 +66,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print(
-        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
+        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s",
     )
     # torch.cuda.synchronize()
@@ -88,14 +88,16 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
-    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
         bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
     torch.cuda.synchronize()
-    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(
+        f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
+    )
     torch.cuda.synchronize()
     t0 = time.time()
@@ -103,7 +105,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         bnb.matmul(A, B)
     torch.cuda.synchronize()
     print(
-        f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
     torch.cuda.synchronize()
@@ -112,7 +114,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         bnb.matmul(A, B, threshold=6.0)
     torch.cuda.synchronize()
     print(
-        f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
     CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
@@ -124,7 +126,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         out32 = F.int8_linear_matmul(CA, CB)
     torch.cuda.synchronize()
     print(
-        f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
     # C32A, SA = F.transform(CA, "col32")
@@ -183,7 +185,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         linear8bit(A)
     torch.cuda.synchronize()
     print(
-        f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
     )
     linearMixedBit(A)
@@ -193,7 +195,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         linearMixedBit(A)
     torch.cuda.synchronize()
     print(
-        f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s"
    )
     # linear8bit_train(A)
...
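The remaining hunks in this benchmark file are pure formatting: with the pre-commit ruff bump from v0.6.9 to v0.11.2, the formatter now normalizes spacing inside f-string replacement fields (f-string formatting became stable in ruff 0.9), so {time.time()-t0:.4f} becomes {time.time() - t0:.4f}. A tiny standalone sketch of the same timing/print pattern:

import time

t0 = time.time()
total = sum(i * i for i in range(1_000_000))  # stand-in for a benchmark kernel
# ruff now formats the expression inside the replacement field like ordinary code:
print(f"sum: {total}, elapsed: {time.time() - t0:.4f}s")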
+from collections.abc import Sequence
 from math import prod
-from typing import Optional, Sequence, Tuple
+from typing import Optional
 import torch
@@ -131,7 +132,7 @@ torch.library.define(
 def _(
     A: torch.Tensor,
     threshold=0.0,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     out_row = torch.empty_like(A, dtype=torch.int8)
     out_col = torch.empty_like(A, dtype=torch.int8)
     row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
@@ -191,7 +192,7 @@ torch.library.define(
 @register_fake("bitsandbytes::quantize_4bit")
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     n = A.numel()
@@ -235,7 +236,7 @@ torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code
 @register_fake("bitsandbytes::quantize_blockwise")
-def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     n = A.numel()
     blocks = -(n // -blocksize)
...
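The `_` functions above are fake (meta) implementations registered for the custom ops: they only propagate output shapes and dtypes so that torch.compile and FakeTensor tracing work without running real kernels, which is why these hunks touch only the return annotations. A minimal sketch of the same pattern, assuming PyTorch >= 2.4 and a hypothetical op name:

import torch

torch.library.define("demo::scale_to_int8", "(Tensor A, float scale) -> (Tensor, Tensor)")

@torch.library.register_fake("demo::scale_to_int8")
def _(A: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake kernels allocate correctly shaped outputs; no real computation happens.
    return torch.empty_like(A, dtype=torch.int8), torch.empty(1, device=A.device)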
 from dataclasses import dataclass
 from math import prod
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional
 import warnings
 from warnings import warn
@@ -55,7 +55,7 @@ class GlobalOutlierPooler:
 )
 def get_inverse_transform_indices(
     transform_tile: Callable[[torch.Tensor], torch.Tensor],
-    tile_size: Tuple[int, int],
+    tile_size: tuple[int, int],
 ):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation
...
 import ctypes as ct
-from typing import Optional, Tuple
+from typing import Optional
 import torch
@@ -47,7 +47,7 @@ def _(
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
-def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
@@ -116,7 +116,7 @@ _NF4_QUANT_TABLE = torch.tensor(
 @register_kernel("bitsandbytes::quantize_4bit", "cpu")
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
...
+from collections.abc import Sequence
 import ctypes as ct
 from math import prod
-from typing import Optional, Sequence, Tuple
+from typing import Optional
 import torch
@@ -78,10 +79,7 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor
         raise NotImplementedError("int8_linear_matmul not implemented!")
     else:
         raise RuntimeError(
-            f"cublasLt ran into an error!\n"
-            f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
-            f"\t{(lda, ldb, ldc)=}\n"
-            f"\t{(m, n, k)=}"
+            f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}"
         )
     return out
@@ -169,7 +167,7 @@ def _(A: torch.Tensor, threshold=0.0):
 def _(
     A: torch.Tensor,
     threshold=0.0,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     # Use CUDA kernel for rowwise and COO tensor
     quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
         A,
@@ -188,7 +186,7 @@ def _(
 def _get_col_absmax(
     A: torch.Tensor,
     threshold=0.0,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     torch._check(A.is_floating_point())
     outlier_mask = None
@@ -207,7 +205,7 @@ def _get_col_absmax(
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
-def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
@@ -292,7 +290,7 @@ def _dequantize_blockwise_impl(
 @register_kernel("bitsandbytes::quantize_4bit", "cuda")
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
     torch._check(
...
 import dataclasses
 from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import Optional
 import torch
 @dataclasses.dataclass(frozen=True)
 class CUDASpecs:
-    highest_compute_capability: Tuple[int, int]
+    highest_compute_capability: tuple[int, int]
     cuda_version_string: str
-    cuda_version_tuple: Tuple[int, int]
+    cuda_version_tuple: tuple[int, int]
     @property
     def has_imma(self) -> bool:
         return torch.version.hip or self.highest_compute_capability >= (7, 5)
-def get_compute_capabilities() -> List[Tuple[int, int]]:
+def get_compute_capabilities() -> list[tuple[int, int]]:
     return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))
 @lru_cache(None)
-def get_cuda_version_tuple() -> Tuple[int, int]:
+def get_cuda_version_tuple() -> tuple[int, int]:
     if torch.version.cuda:
         return map(int, torch.version.cuda.split(".")[0:2])
     elif torch.version.hip:
...
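The CUDASpecs fields above are plain (major, minor) tuples, and feature gates like has_imma compare them lexicographically. A hedged sketch of the same check, assuming at least one visible CUDA device:

import torch

if torch.cuda.is_available():
    cc: tuple[int, int] = torch.cuda.get_device_capability(0)
    # IMMA (int8 tensor core) kernels require Turing (sm_75) or newer.
    print(f"compute capability {cc}, IMMA support: {cc >= (7, 5)}")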
+from collections.abc import Iterable, Iterator
 import logging
 import os
 from pathlib import Path
-from typing import Dict, Iterable, Iterator
 import torch
@@ -76,7 +76,7 @@ def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
 )
-def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
+def get_potentially_lib_path_containing_env_vars() -> dict[str, str]:
     return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)}
...
@@ -2,10 +2,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from collections.abc import Iterable
 import ctypes as ct
 import itertools
 from math import prod
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Optional, Union
 import numpy as np
 import torch
@@ -619,7 +620,7 @@ class QuantState:
         return list_repr[idx]
     @classmethod
-    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState":
+    def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState":
         """
         unpacks components of state_dict into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
@@ -741,7 +742,7 @@ def quantize_blockwise(
     out: Optional[torch.Tensor] = None,
     blocksize=4096,
     nested=False,
-) -> Tuple[torch.Tensor, QuantState]:
+) -> tuple[torch.Tensor, QuantState]:
     """Quantize a tensor in blocks of values.
     The input tensor is quantized by dividing it into blocks of `blocksize` values.
@@ -994,7 +995,7 @@ def quantize_4bit(
     compress_statistics=False,
     quant_type="fp4",
     quant_storage=torch.uint8,
-) -> Tuple[torch.Tensor, QuantState]:
+) -> tuple[torch.Tensor, QuantState]:
     """Quantize tensor A in blocks of 4-bit values.
     Quantizes tensor A by dividing it into blocks which are independently quantized.
@@ -1161,7 +1162,7 @@ def quantize(
     A: Tensor,
     code: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+) -> tuple[Tensor, tuple[Tensor, Tensor]]:
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -1179,7 +1180,7 @@ def quantize(
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequantize(
     A: Tensor,
-    state: Optional[Tuple[Tensor, Tensor]] = None,
+    state: Optional[tuple[Tensor, Tensor]] = None,
     absmax: Optional[torch.Tensor] = None,
     code: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
@@ -2006,7 +2007,7 @@ def get_colrow_absmax(
     col_stats: Optional[torch.Tensor] = None,
     nnz_block_ptr: Optional[torch.Tensor] = None,
     threshold=0.0,
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
     The row-wise and column-wise absmax values are determined.
@@ -2268,9 +2269,9 @@ def spmm_coo(
     out: Optional[torch.Tensor] = None,
 ):
     if not isinstance(cooA, COOSparseTensor):
-        assert (
-            cooA.is_sparse and cooA.layout == torch.sparse_coo
-        ), "Tensor must be `COOSparseTensor or a PyTorch COO tensor."
+        assert cooA.is_sparse and cooA.layout == torch.sparse_coo, (
+            "Tensor must be `COOSparseTensor or a PyTorch COO tensor."
+        )
         # Convert to custom COOSparseTensor
         cooA = COOSparseTensor(
...
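The signatures above change only their annotations; quantize_blockwise still returns the packed tensor plus a QuantState. A hedged round-trip sketch, assuming a CUDA-enabled bitsandbytes install:

import torch
import bitsandbytes.functional as F

A = torch.randn(16, 1024, device="cuda")
# Annotated as tuple[Tensor, QuantState] after this commit:
packed, quant_state = F.quantize_blockwise(A, blocksize=4096)
restored = F.dequantize_blockwise(packed, quant_state)
print((A - restored).abs().max())  # small blockwise quantization error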
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
-from typing import Any, Dict, Optional, TypeVar, Union, overload
+from typing import Any, Optional, TypeVar, Union, overload
 import warnings
 import torch
@@ -268,7 +268,7 @@ class Params4bit(torch.nn.Parameter):
     def from_prequantized(
         cls,
         data: torch.Tensor,
-        quantized_stats: Dict[str, Any],
+        quantized_stats: dict[str, Any],
         requires_grad: bool = False,
         device="cuda",
         module: Optional["Linear4bit"] = None,
...
+from collections.abc import Iterable
 import math
-from typing import Iterable, Literal, Optional, Tuple
+from typing import Literal, Optional
 import torch
@@ -16,7 +17,7 @@ class _ReferenceAdEMAMix(torch.optim.Optimizer):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         eps: float = 1e-8,
         weight_decay: float = 1e-2,  # default 0.0 or 1e-2?
@@ -108,7 +109,7 @@ class AdEMAMix(Optimizer2State):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         t_alpha: Optional[int] = None,
         t_beta3: Optional[int] = None,
@@ -151,7 +152,7 @@ class AdEMAMix(Optimizer2State):
         elif config["optim_bits"] == 8:
             dtype = torch.uint8
         else:
-            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+            raise NotImplementedError(f"Amount of optimizer bits not supported: {config['optim_bits']}")
         if p.numel() < config["min_8bit_size"]:
             dtype = torch.float32
@@ -274,7 +275,7 @@ class AdEMAMix8bit(AdEMAMix):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         t_alpha: Optional[int] = None,
         t_beta3: Optional[int] = None,
@@ -303,7 +304,7 @@ class PagedAdEMAMix8bit(AdEMAMix8bit):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         t_alpha: Optional[int] = None,
         t_beta3: Optional[int] = None,
@@ -330,7 +331,7 @@ class PagedAdEMAMix(AdEMAMix):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         t_alpha: Optional[int] = None,
         t_beta3: Optional[int] = None,
@@ -359,7 +360,7 @@ class AdEMAMix32bit(Optimizer2State):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         t_alpha: Optional[int] = None,
         t_beta3: Optional[int] = None,
@@ -392,7 +393,7 @@ class PagedAdEMAMix32bit(AdEMAMix32bit):
         self,
         params: Iterable[torch.nn.Parameter],
         lr: float = 1e-3,
-        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
         alpha: float = 5.0,
         t_alpha: Optional[int] = None,
         t_beta3: Optional[int] = None,
...
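As the signatures show, AdEMAMix takes a three-element betas tuple: beta1 and beta2 play their usual Adam roles, while beta3 drives the slow EMA that the optimizer mixes in with weight alpha. A hedged usage sketch, assuming a CUDA build of bitsandbytes:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 128).cuda()
opt = bnb.optim.AdEMAMix(model.parameters(), lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0)

loss = model(torch.randn(4, 128, device="cuda")).pow(2).mean()
loss.backward()
opt.step()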
@@ -450,7 +450,7 @@ class Optimizer2State(Optimizer8bit):
         elif config["optim_bits"] == 8:
             dtype = torch.uint8
         else:
-            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+            raise NotImplementedError(f"Amount of optimizer bits not supported: {config['optim_bits']}")
         if p.numel() < config["min_8bit_size"]:
             dtype = torch.float32
@@ -677,7 +677,7 @@ class Optimizer1State(Optimizer8bit):
         elif config["optim_bits"] == 8:
             dtype = torch.uint8
         else:
-            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+            raise NotImplementedError(f"Amount of optimizer bits not supported: {config['optim_bits']}")
         if p.numel() < config["min_8bit_size"]:
             dtype = torch.float32
...
@@ -128,7 +128,7 @@ def estimate_matmul_time(
     print(
         f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
         f"loading time: {load_ms}ms, store time: {store_ms}ms, "
-        f"Activate CTAs: {active_cta_ratio*100}%"
+        f"Activate CTAs: {active_cta_ratio * 100}%"
     )
     return total_time_ms
...
 import json
 import shlex
 import subprocess
-from typing import Tuple
 import torch
@@ -104,7 +103,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
     return idx
-def execute_and_return(command_string: str) -> Tuple[str, str]:
+def execute_and_return(command_string: str) -> tuple[str, str]:
     def _decode(subprocess_err_out_tuple):
         return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple)
...
@@ -8,7 +8,7 @@ text = "Hamburg is in which country?\n"
 tokenizer = LlamaTokenizer.from_pretrained(model_name)
 input_ids = tokenizer(text, return_tensors="pt").input_ids
-max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
+max_memory = f"{int(torch.cuda.mem_get_info()[0] / 1024**3) - 2}GB"
 n_gpus = torch.cuda.device_count()
 max_memory = {i: max_memory for i in range(n_gpus)}
...
@@ -11,7 +11,7 @@ maintainers = [
     {name="Titus von Köller", email="titus@huggingface.co"},
     {name="Matthew Douglas", email="matthew.douglas@huggingface.co"}
 ]
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 readme = "README.md"
 license = {file="LICENSE"}
 keywords = [
@@ -34,11 +34,11 @@ classifiers = [
     "Operating System :: Microsoft :: Windows",
     "Programming Language :: C++",
     "Programming Language :: Python :: Implementation :: CPython",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
     "Topic :: Scientific/Engineering :: Artificial Intelligence"
 ]
 dependencies = [
@@ -58,7 +58,7 @@ docs = ["hf-doc-builder==0.5.0"]
 dev = [
     "bitsandbytes[test]",
     "build>=1.0.0,<2",
-    "ruff==0.9.6",
+    "ruff==0.11.2",
     "pre-commit>=3.5.0,<4",
     "wheel>=0.42,<1"
 ]
@@ -66,7 +66,6 @@ test = [
     "einops~=0.8.0",
     "lion-pytorch==0.2.3",
     "pytest~=8.3",
-    "scipy>=1.10.1,<2; python_version < '3.9'",
     "scipy>=1.11.4,<2; python_version >= '3.9'",
     "transformers>=4.30.1,<5"
 ]
@@ -101,7 +100,7 @@ src = [
     "tests",
     "benchmarking"
 ]
-target-version = "py38"
+target-version = "py39"
 line-length = 119
 [tool.ruff.lint]
@@ -124,6 +123,7 @@ ignore = [
     "E731",  # Do not use lambda
     "F841",  # Local assigned but not used (TODO: enable, these are likely bugs)
     "RUF012", # Mutable class attribute annotations
+    "RUF034", # Useless if-else (TODO: enable)
     "ISC001", # single-line-implicit-string-concatenation incompatible with formatter
 ]
...
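requires-python = ">=3.9" makes pip refuse to resolve new releases on Python 3.8 interpreters, while target-version = "py39" lets ruff apply 3.9-only rewrites such as the PEP 585 generics used throughout this diff. A runtime guard with the same intent would be a one-liner (hypothetical, not part of bitsandbytes):

import sys

if sys.version_info < (3, 9):
    raise RuntimeError("this release requires Python 3.9 or newer")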
 from io import BytesIO
 from itertools import product
 import random
-from typing import Any, List
+from typing import Any
 import torch
@@ -27,7 +27,7 @@ def torch_load_from_buffer(buffer):
     return obj
-def get_test_dims(min: int, max: int, *, n: int) -> List[int]:
+def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
     return [test_dims_rng.randint(min, max) for _ in range(n)]
...
@@ -674,12 +674,12 @@ class TestLLMInt8Functional:
         min_error = 1 / 500
         if num_not_close_cols > (min_error * n):
             print(
-                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
+                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols / n:.4f}"
             )
             assert False
         if num_not_close_rows > (min_error * n):
             print(
-                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
+                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows / n:.4f}"
             )
             assert False
...