Unverified Commit 706ec24d authored by Aarni Koskela, committed by GitHub

Ruff fixes (#984)



* Adjust Ruff configuration

* do not always autofix
* be less strict around tests and benchmarks
* adjust ignores for now

* Ruff: autofix I and F401

* Apply ruff autofixes

* Fix RUF013 complaint
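
RUF013 flags an implicit `Optional`: a parameter annotated with a plain type whose default is `None`. A minimal sketch of the before/after pattern (illustrative names, not the actual signature that was fixed):

```python
from typing import Optional

def load(path: str = None):  # RUF013: annotated str, but defaults to None
    return path

def load_fixed(path: Optional[str] = None):  # the Optional spelled out
    return path
```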

* Fix mutable default in replace_linear
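
A mutable default is created once at function definition time and shared by every call that omits the argument. A hedged sketch of the pattern (the real `replace_linear` signature in bitsandbytes may differ):

```python
def collect(item, bucket=[]):  # B006: one list shared across all calls
    bucket.append(item)
    return bucket

def collect_fixed(item, bucket=None):
    if bucket is None:
        bucket = []  # a fresh list per call
    bucket.append(item)
    return bucket

assert collect(1) == [1]
assert collect(2) == [1, 2]     # state leaked from the first call
assert collect_fixed(2) == [2]
```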

* Don't use bare except
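
A bare `except:` (E722) also catches `KeyboardInterrupt` and `SystemExit`, so Ctrl-C can get swallowed. Catching `Exception` and chaining preserves those escape hatches and the original traceback. Illustrative sketch:

```python
def parse_port(value):
    try:
        return int(value)
    except Exception as e:  # instead of a bare `except:`
        raise ValueError(f"invalid port: {value!r}") from e

print(parse_port("8080"))
```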

* Wrap bitsandbytes.__main__ entrypoint in function; fix "sensible" typo

* Fix ruff B008 (function call in arguments)
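
B008 warns that a function call used as a default argument runs once at definition time, not per call; that is why `torch.device('cuda', index=0)` is hoisted into a `FIRST_CUDA_DEVICE` constant in `functional.py` below, and why the `state=MatmulLtState()` default in the research module only gets a `noqa` plus a TODO. A generic sketch:

```python
import time

def stamp(value, ts=time.time()):  # B008: timestamp frozen at import time
    return value, ts

def stamp_fixed(value, ts=None):
    if ts is None:
        ts = time.time()  # evaluated on every call
    return value, ts
```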

* Add ruff noqas where suitable

* Fix RUF005 (splat instead of concatenating)
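
RUF005 prefers iterable unpacking over `+`-concatenation when building a new list around existing ones:

```python
head = [0]
tail = [1, 2, 3]

combined = head + tail      # RUF005 flags the concatenation
combined = [*head, *tail]   # preferred: unpack into a list literal
assert combined == [0, 1, 2, 3]
```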

* Fix B018 (useless expression)
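
B018 flags statements with no effect, such as a lone attribute access. In `cextension.py` the access is an intentional probe (it raises if the native symbol is missing), so the fix assigns to `_` to make that explicit, as the diff below shows. A toy stand-in for the real ctypes library handle:

```python
class FakeLib:  # hypothetical stand-in for the ctypes handle
    cadam32bit_grad_fp32 = object()

lib = FakeLib()

lib.cadam32bit_grad_fp32      # B018: reads as a useless expression
_ = lib.cadam32bit_grad_fp32  # same probe; intent is now explicit
```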

* Add pre-commit configuration + GitHub Actions lint workflow
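
With the configuration added below, the same checks run locally via `pre-commit run --all-files` (after a one-time `pre-commit install`); the `RUFF_OUTPUT_FORMAT: github` variable in the workflow only switches Ruff to GitHub-style annotations in CI.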

* Fix unused `e` in bitsandbytes/__main__.py

* fix merge conflict resolution error

* run pre-commit hook

---------
Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>
parent a8c9dfa6
New GitHub Actions lint workflow:

name: Lint
on:
  push:
    branches:
      - main
  pull_request:
jobs:
  Lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: "3.12"
      - uses: pre-commit/action@v3.0.0
        env:
          RUFF_OUTPUT_FORMAT: github
New pre-commit configuration:

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.15
    hooks:
      - id: ruff
        args:
          - --fix
      # - id: ruff-format # TODO: enable when the time is right
Benchmark plotting script (imports deduplicated and sorted):

import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

cmap=plt.get_cmap('cool')
......
SwitchBack speed benchmark (imports deduplicated and sorted):

import json
import time

import torch
import torch.nn as nn

from bitsandbytes.triton.int8_matmul_mixed_dequantize import (
    int8_matmul_mixed_dequantize,
)
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import (
    int8_matmul_rowwise_dequantize,
)
from bitsandbytes.triton.quantize_columnwise_and_transpose import (
    quantize_columnwise_and_transpose,
)
from bitsandbytes.triton.quantize_global import (
    quantize_global,
    quantize_global_transpose,
)
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise

# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
......
bitsandbytes/__init__.py:

@@ -3,14 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import cuda_setup, research, utils
from .autograd._functions import (
    MatmulLtState,
    bmm_cublas,
    matmul,
    matmul_4bit,
    matmul_cublas,
    mm_cublas,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
......
bitsandbytes/__main__.py:

import os
from os.path import isdir
import sys
from warnings import warn

import torch

@@ -20,7 +16,7 @@ def find_file_recursive(folder, filename):
            out = glob.glob(os.path.join(folder, "**", filename + ext))
            outs.extend(out)
    except Exception as e:
        raise RuntimeError('Error: Something went wrong when trying to find file.') from e
    return outs

@@ -62,14 +58,11 @@ def generate_bug_report_information():
            print_header(f"{path} CUDA PATHS")
            paths = find_file_recursive(path, '*cuda*')
            print(paths)
        except Exception as e:
            print(f'Could not read LD_LIBRARY_PATH: {path} ({e})')
    print('')
def print_header(
    txt: str, width: int = HEADER_WIDTH, filler: str = "+"
) -> None:

@@ -78,67 +71,61 @@ def print_header(

def print_debug_info() -> None:
    from . import PACKAGE_GITHUB_URL

    print(
        "\nAbove we output some debug information. Please provide this info when "
        f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n"
    )


def main():
    generate_bug_report_information()

    from . import COMPILED_WITH_CUDA
    from .cuda_setup.main import get_compute_capabilities

    print_header("OTHER")
    print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
    print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}")
    print_header("")
    print_header("DEBUG INFO END")
    print_header("")
    print("Checking that the library is importable and CUDA is callable...")
    print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n")

    try:
        from bitsandbytes.optim import Adam

        p = torch.nn.Parameter(torch.rand(10, 10).cuda())
        a = torch.rand(10, 10).cuda()

        p1 = p.data.sum().item()

        adam = Adam([p])

        out = a * p
        loss = out.sum()
        loss.backward()
        adam.step()

        p2 = p.data.sum().item()

        assert p1 != p2
        print("SUCCESS!")
        print("Installation was successful!")
    except ImportError:
        print()
        warn(
            f"WARNING: {__package__} is currently running as CPU-only!\n"
            "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
            f"If you think that this is so erroneously,\nplease report an issue!"
        )
        print_debug_info()
        sys.exit(0)
    except Exception as e:
        print(e)
        print_debug_info()
        sys.exit(1)


if __name__ == "__main__":
    main()
bitsandbytes/autograd/__init__.py:

from ._functions import get_inverse_transform_indices, undo_layout
bitsandbytes/autograd/_functions.py:

from dataclasses import dataclass
from functools import reduce  # Required in Python 3
import operator
from typing import Callable, Optional, Tuple
import warnings
from warnings import warn

import torch
......
bitsandbytes/cextension.py:

import ctypes as ct
from warnings import warn

import torch

from bitsandbytes.cuda_setup.main import CUDASetup

setup = CUDASetup.get_instance()
if setup.initialized != True:

@@ -25,7 +22,7 @@ try:
    Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
    to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
    and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')

    _ = lib.cadam32bit_grad_fp32  # raises an error if the library could not be found -> COMPILED_WITH_CUDA=False

    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    lib.cget_managed_ptr.restype = ct.c_void_p
......
bitsandbytes/cuda_setup/main.py:

@@ -17,15 +17,15 @@ evaluation:
"""

import ctypes as ct
import errno
import os
from pathlib import Path
import platform
from typing import Set, Union
from warnings import warn

import torch

from .env_vars import get_potentially_lib_path_containing_env_vars

# these are the most common libs names
@@ -111,14 +111,16 @@ class CUDASetup:
        if torch.cuda.is_available():
            if 'BNB_CUDA_VERSION' in os.environ:
                if len(os.environ['BNB_CUDA_VERSION']) > 0:
                    warn(
                        f'\n\n{"=" * 80}\n'
                        'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n'
                        'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n'
                        'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n'
                        'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n'
                        'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n'
                        f'Loading CUDA version: BNB_CUDA_VERSION={os.environ["BNB_CUDA_VERSION"]}'
                        f'\n{"=" * 80}\n\n'
                    )

                    binary_name = self.binary_name.rsplit(".", 1)[0]
                    suffix = ".so" if os.name != "nt" else ".dll"
                    self.binary_name = binary_name[:-3] + f'{os.environ["BNB_CUDA_VERSION"]}.{suffix}'
@@ -207,7 +209,7 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
        try:
            if path.exists():
                existent_directories.add(path)
        except PermissionError:
            # Handle the PermissionError first as it is a subtype of OSError
            # https://docs.python.org/3/library/exceptions.html#exception-hierarchy
            pass
@@ -217,8 +219,10 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
    non_existent_directories: Set[Path] = candidate_paths - existent_directories
    if non_existent_directories:
        CUDASetup.get_instance().add_log_entry(
            f"The following directories listed in your path were found to be non-existent: {non_existent_directories}",
            is_warning=False,
        )

    return existent_directories
@@ -360,8 +364,10 @@ def evaluate_cuda_setup():
    cuda_version_string = get_cuda_version()

    cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.")
    cuda_setup.add_log_entry(
        "CUDA SETUP: To manually override the PyTorch CUDA version please see: "
        "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md"
    )

    # 7.5 is the minimum CC for cublaslt
......
bitsandbytes/functional.py:

@@ -3,17 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
from functools import reduce  # Required in Python 3
import itertools
import operator
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

from .cextension import COMPILED_WITH_CUDA, lib
@@ -178,7 +176,9 @@ dtype2bytes[torch.bfloat16] = 2
dtype2bytes[torch.uint8] = 1
dtype2bytes[torch.int8] = 1

FIRST_CUDA_DEVICE = torch.device('cuda', index=0)

def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
    num_bytes = dtype2bytes[dtype]*prod(shape)
    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
@@ -242,7 +242,7 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True):
    if gap == 0:
        return values
    else:
        l = values.numel()//2  # noqa: E741
        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
@@ -283,7 +283,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
    # the exponent is biased to 2^(e-1) -1 == 0
    evalues = []
    pvalues = []
    for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)):
        evalues.append(2**val)
@@ -345,7 +345,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    non_sign_bits = total_bits - (1 if signed else 1)
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    for i in range(max_exponent_bits):
        fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
@@ -899,7 +899,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
            -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
            0.42563882, 0.55496234, 0.72424863, 1.][::-1]
    else:
        raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.')

    if data is None:
        raise NotImplementedError(f'Typename {typename} not supported')
@@ -1635,10 +1635,10 @@ def gemv_4bit(
    prev_device = pre_call(A.device)
    #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
    if state is None:
        raise ValueError('state cannot be None. gemv_4bit() requires the state from quantize_4bit()')

    if A.numel() != A.shape[-1]:
        raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

    Bshape = state.shape
    bout = Bshape[0]
......
bitsandbytes/nn/__init__.py:

@@ -2,5 +2,21 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import (
    Embedding,
    Int8Params,
    Linear4bit,
    Linear8bitLt,
    LinearFP4,
    LinearNF4,
    OutlierAwareLinear,
    Params4bit,
    StableEmbedding,
    SwitchBackLinearBnb,
)
from .triton_based_modules import (
    StandardLinear,
    SwitchBackLinear,
    SwitchBackLinearGlobal,
    SwitchBackLinearVectorwise,
)
bitsandbytes/nn/modules.py:

@@ -3,17 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

import torch
from torch import Tensor, device, dtype, nn
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer

T = TypeVar("T", bound="torch.nn.Module")

@@ -242,10 +242,10 @@ class Linear4bit(nn.Linear):
        if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
            # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
            # warn the user about this
            warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
            warnings.filterwarnings('ignore', message='.*inference.')
        if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
            warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
            warnings.filterwarnings('ignore', message='.*inference or training')

    def _save_to_state_dict(self, destination, prefix, keep_vars):

@@ -337,8 +337,8 @@ class Int8Params(torch.nn.Parameter):
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

            return self
......
bitsandbytes/nn/triton_based_modules.py:

from functools import partial

import torch
import torch.nn as nn

from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.int8_matmul_mixed_dequantize import (
    int8_matmul_mixed_dequantize,
)
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import (
    int8_matmul_rowwise_dequantize,
)
from bitsandbytes.triton.quantize_columnwise_and_transpose import (
    quantize_columnwise_and_transpose,
)
from bitsandbytes.triton.quantize_global import (
    quantize_global,
    quantize_global_transpose,
)
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.triton_utils import is_triton_available


class _switchback_global(torch.autograd.Function):
......
bitsandbytes/optim/__init__.py:

@@ -7,10 +7,17 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
from .adamw import (
    AdamW,
    AdamW8bit,
    AdamW32bit,
    PagedAdamW,
    PagedAdamW8bit,
    PagedAdamW32bit,
)
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .sgd import SGD, SGD8bit, SGD32bit
bitsandbytes/optim/adamw.py:

@@ -5,7 +5,6 @@
from bitsandbytes.optim.optimizer import Optimizer2State


class AdamW(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
......
bitsandbytes/optim/lion.py:

@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Lion(Optimizer1State):
    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
......
bitsandbytes/optim/optimizer.py:

@@ -2,8 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs, defaultdict
from copy import deepcopy
from itertools import chain
......
bitsandbytes/research/__init__.py:

from . import nn
from .autograd._functions import (
    matmul_fp8_global,
    matmul_fp8_mixed,
    switchback_bnb,
)
bitsandbytes/research/autograd/_functions.py:

from functools import reduce  # Required in Python 3
import operator
from typing import Optional
import warnings

import torch

from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState
import bitsandbytes.functional as F

# math.prod not compatible with python < 3.8
def prod(iterable):

@@ -186,7 +184,9 @@ class MatMulFP8Global(torch.autograd.Function):

class SwitchBackBnb(torch.autograd.Function):
    @staticmethod
    # TODO: the B008 on the line below is a likely bug; the current implementation will
    # have each SwitchBackBnb instance share a single MatmulLtState instance!!!
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B008
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
......