Unverified Commit 706ec24d authored by Aarni Koskela, committed by GitHub

Ruff fixes (#984)



* Adjust Ruff configuration

* do not always autofix
* be less strict around tests and benchmarks
* adjust ignores for now

* Ruff: autofix I and F401

* Apply ruff autofixes

* Fix RUF013 complaint

* Fix mutable default in replace_linear

* Don't use bare except

* Wrap bitsandbytes.__main__ entrypoint in function; fix "sensible" typo

* Fix ruff B008 (function call in default arguments; see the sketches after this list)

* Add ruff noqas as suitable

* Fix RUF005 (splat instead of concatenating)

* Fix B018 (useless expression)

* Add pre-commit configuration + GitHub Actions lint workflow

* Fix unused `e` in bitsandbytes/__main__.py

* fix merge conflict resolution error

* run pre-commit hook
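
Several of the bullets above name specific Ruff rules. A minimal sketch of the pattern each rule flags, with illustrative code only (none of it is from the repository):

```python
from typing import Optional

# Bare except: catches SystemExit and KeyboardInterrupt too, so name the class.
try:
    value = int("not a number")
except ValueError:  # was a bare `except:`
    value = 0

# B008: a function call in a default argument runs once, at definition time.
def load(path, defaults=dict()):  # flagged: one shared dict for all calls
    return {**defaults, "path": path}

def load_fixed(path, defaults=None):  # evaluated per call instead
    defaults = {} if defaults is None else defaults
    return {**defaults, "path": path}

# B018: a bare expression statement has no effect.
def check(x):
    x == 0  # flagged: the comparison result is silently discarded
    assert x == 0  # an assert (or a return) states the intent

# RUF013: a None default makes the parameter implicitly Optional.
def lookup(key: str, default: Optional[str] = None) -> Optional[str]:
    return default
```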

---------
Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>
parent a8c9dfa6
-from .modules import LinearFP8Mixed, LinearFP8Global
+from .modules import LinearFP8Global, LinearFP8Mixed
...
-from typing import Optional, TypeVar, Union, overload
+from typing import TypeVar

 import torch
-import torch.nn.functional as F
-from torch import Tensor, device, dtype, nn
+from torch import nn

 import bitsandbytes as bnb
-from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import OutlierTracer, find_outlier_dims

 T = TypeVar("T", bound="torch.nn.Module")
...
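
The hunk above is the "Ruff: autofix I and F401" step at work: unused imports are dropped and the rest are sorted into isort-style groups. Schematically, the ordering the `I` rules enforce (a sketch, not the project's code):

```python
# 1) standard library, alphabetized
import math
from typing import TypeVar

# 2) third-party packages
import torch

# 3) first-party / local packages
import bitsandbytes as bnb

T = TypeVar("T")
print(math.pi, torch.__version__, bnb.__name__)  # use the imports so F401 stays quiet
```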
 import math
+
 import torch
-import time
+
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
@@ -9,7 +10,6 @@ else:
     import triton
     import triton.language as tl
-    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

 # rowwise quantize
...
 import torch

 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
@@ -57,7 +58,8 @@ else:
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
-    ] + get_configs_io_bound(),
+        *get_configs_io_bound(),
+    ],
     key=['M', 'N', 'K'],
     prune_configs_by={
         'early_config_prune': early_config_prune,
...
@@ -57,7 +57,8 @@ else:
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
-    ] + get_configs_io_bound(),
+        *get_configs_io_bound(),
+    ],
     key=['M', 'N', 'K'],
     prune_configs_by={
         'early_config_prune': early_config_prune,
...
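
Both hunks above are the same RUF005 fix ("splat instead of concatenating"): rather than concatenating the list literal of `triton.Config` objects with the list returned by `get_configs_io_bound()`, the call's results are unpacked into the literal itself. The pattern in isolation (hypothetical values):

```python
def extra_configs():
    return ["io-bound-1", "io-bound-2"]

# Before: `+` builds a temporary second list (flagged by RUF005).
configs = ["a", "b", "c"] + extra_configs()

# After: one list display; `*` splats the iterable in place.
configs = ["a", "b", "c", *extra_configs()]
```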
 import math
+
 import torch
-import time
+
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
@@ -9,7 +10,6 @@ else:
     import triton
     import triton.language as tl
-    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

 # This kernel does fused columnwise quantization and transpose.
...
-import math
 import torch
-import time
+
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
@@ -10,7 +10,6 @@ else:
     import triton
     import triton.language as tl
-    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

 # global quantize
 @triton.autotune(
...
 import math
+
 import torch
-import time
+
 from bitsandbytes.triton.triton_utils import is_triton_available
@@ -10,7 +10,6 @@ else:
     import triton
     import triton.language as tl
-    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

 # rowwise quantize
...
 import importlib

 def is_triton_available():
     return importlib.util.find_spec("triton") is not None
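
`is_triton_available` probes for the package without importing it: `importlib.util.find_spec` returns `None` when no module spec can be located, so the check is cheap and side-effect free. The same pattern works for any optional dependency (a sketch; `has_module` is a made-up helper):

```python
import importlib.util

def has_module(name: str) -> bool:
    # find_spec locates a top-level module without executing it.
    return importlib.util.find_spec(name) is not None

if has_module("triton"):
    import triton  # noqa: F401  -- only imported when actually present
```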
 import json
 import shlex
 import subprocess
-import torch
 from typing import Tuple
+
+import torch

 def outlier_hook(module, input):
     assert isinstance(module, torch.nn.Linear)
     tracer = OutlierTracer.get_instance()
@@ -37,7 +39,7 @@ def outlier_hook(module, input):
     hook.remove()

-class OutlierTracer(object):
+class OutlierTracer:
     _instance = None

     def __init__(self):
@@ -122,7 +124,13 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:

-def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
+def replace_linear(
+    model,
+    linear_replacement,
+    skip_modules=("lm_head",),
+    copy_weights=False,
+    post_processing_function=None,
+):
     """
     Replace linear modules with a new Linear module.

     Parameters:
...
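
The `replace_linear` rewrite above is the "Fix mutable default in replace_linear" bullet: a list default such as `skip_modules=["lm_head"]` is created once at definition time and shared by every call, so any mutation leaks between calls; the immutable tuple `("lm_head",)` avoids that. A minimal demonstration of the pitfall (illustrative only):

```python
def remember(item, seen=[]):  # one list, created when the def runs
    seen.append(item)
    return seen

print(remember("a"))  # ['a']
print(remember("b"))  # ['a', 'b'] -- state leaked from the first call

def remember_fixed(item, seen=None):  # None sentinel: fresh list per call
    seen = [] if seen is None else seen
    seen.append(item)
    return seen
```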
-import bitsandbytes as bnb
 import torch
+
+import bitsandbytes as bnb

 p = torch.nn.Parameter(torch.rand(10,10).cuda())
 a = torch.rand(10,10).cuda()
...
 import os
-import sys
 import subprocess
+import sys
 from urllib.request import urlretrieve

 cuda_versions = {
...
@@ -11,9 +11,7 @@ src = [
     "tests",
     "benchmarking"
 ]
-fix = true
 select = [
-    "A", # prevent using keywords that clobber python builtins
     "B", # bugbear: security warnings
     "E", # pycodestyle
     "F", # pyflakes
@@ -24,12 +22,29 @@ select = [
 ]
 target-version = "py38"
 ignore = [
-    "E712", # Allow using if x == False, as it's not always equivalent to if x.
+    "B007", # Loop control variable not used within the loop body (TODO: enable)
+    "B028", # Warning without stacklevel (TODO: enable)
     "E501", # Supress line-too-long warnings: trust yapf's judgement on this one.
-    "F401",
+    "E701", # Multiple statements on one line (TODO: enable)
+    "E712", # Allow using if x == False, as it's not always equivalent to if x.
+    "E731", # Do not use lambda
+    "F841", # Local assigned but not used (TODO: enable, these are likely bugs)
+    "RUF012", # Mutable class attribute annotations
 ]
 ignore-init-module-imports = true # allow to expose in __init__.py via imports

+[tool.ruff.extend-per-file-ignores]
+"**/__init__.py" = ["F401"] # allow unused imports in __init__.py
+"{benchmarking,tests}/**/*.py" = [
+    "B007",
+    "B011",
+    "B023",
+    "E701",
+    "E731",
+    "F841",
+    "UP030",
+]
+
 [tool.ruff.isort]
 combine-as-imports = true
 detect-same-package = true
...
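
For reference, the kind of code the newly listed ignores would otherwise flag (illustrative snippets, not from the repository):

```python
# B007: the loop control variable is never used in the body.
for i in range(3):
    print("tick")

# E731: a lambda bound to a name; Ruff prefers a def.
square = lambda x: x * x

# F841: a local assigned but never read (often a latent bug, hence the TODO).
def f():
    result = 40 + 2  # flagged: `result` is never used
    return 0

# The __init__.py per-file ignore exists because re-exports look unused
# to pyflakes (F401), e.g.:
#     from .modules import LinearFP8Global, LinearFP8Mixed
```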
@@ -15,13 +15,11 @@
 Script to close stale issue. Taken in part from the AllenNLP repository.
 https://github.com/allenai/allennlp.
 """
+from datetime import datetime as dt, timezone
 import os
-from datetime import datetime as dt
-from datetime import timezone

 from github import Github

 # All labels that we don't want to touch
 LABELS_TO_EXEMPT = [
     "feature-request",
...
@@ -7,7 +7,6 @@ import os
 from setuptools import find_packages, setup

 libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
 libs += list(glob.glob("./bitsandbytes/libbitsandbytes*.dll"))
 libs = [os.path.basename(p) for p in libs]
@@ -19,7 +18,7 @@ def read(fname):
 setup(
-    name=f"bitsandbytes",
+    name="bitsandbytes",
     version="0.42.0",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
...
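
The `setup()` change is an F541 fix: an f-string with no placeholders is just a plain string wearing a misleading prefix.

```python
name = f"bitsandbytes"  # F541: nothing is interpolated, the f prefix is inert
name = "bitsandbytes"   # equivalent, and doesn't suggest interpolation
```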
-from itertools import permutations, product
+from itertools import product

 import pytest
 import torch
...
 import os
-import pytest
-import torch
 from pathlib import Path

+import torch

 # hardcoded test. Not good, but a sanity check for now
 # TODO: improve this

 def test_manual_override(requires_cuda):
...
+from itertools import product
 import math
 import random
 import time
-from itertools import product

 import einops
+import numpy as np
 import pytest
+from scipy.stats import norm
 import torch
-import numpy as np

 import bitsandbytes as bnb
 from bitsandbytes import functional as F
-from scipy.stats import norm

 torch.set_printoptions(
     precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
...
-import pytest
-import torch
-import math
 from itertools import product
+import math
+
+import pytest
+import torch

 import transformers
 from transformers import (
-    AutoConfig,
     AutoModelForCausalLM,
-    AutoTokenizer,
     BitsAndBytesConfig,
-    GenerationConfig,
-    set_seed,
 )

 def get_4bit_config():
     return BitsAndBytesConfig(
         load_in_4bit=True,
...
-import os
+from contextlib import nullcontext
 from itertools import product
+import os
 from tempfile import TemporaryDirectory

 import pytest
...