Unverified Commit 048a2d40 authored by Aarni Koskela's avatar Aarni Koskela Committed by GitHub
Browse files

Deduplicate helpers & fix lint issues from #1099 (#1107)

parent a1c0844b
from io import BytesIO
from itertools import product from itertools import product
import random import random
from typing import Any, List from typing import Any, List
...@@ -7,6 +8,25 @@ import torch ...@@ -7,6 +8,25 @@ import torch
test_dims_rng = random.Random(42) test_dims_rng = random.Random(42)
# The two boolean values, used as a parametrization source in tests.
TRUE_FALSE = (True, False)
# Every (bool, bool, bool) combination — 8 triples.
BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))
# Every (bool, bool) combination — 4 pairs.
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))
def torch_save_to_buffer(obj):
    """Serialize *obj* with ``torch.save`` into an in-memory buffer.

    Returns a ``BytesIO`` rewound to position 0, ready to be read back.
    """
    out = BytesIO()
    torch.save(obj, out)
    out.seek(0)
    return out
def torch_load_from_buffer(buffer):
    """Deserialize an object from *buffer* via ``torch.load``.

    The buffer is rewound both before and after reading, so the same
    buffer can be loaded from repeatedly.
    """
    buffer.seek(0)
    loaded = torch.load(buffer)
    buffer.seek(0)
    return loaded
def get_test_dims(min: int, max: int, *, n: int) -> List[int]: def get_test_dims(min: int, max: int, *, n: int) -> List[int]:
return [test_dims_rng.randint(min, max) for _ in range(n)] return [test_dims_rng.randint(min, max) for _ in range(n)]
...@@ -42,10 +62,3 @@ DTYPE_NAMES = { ...@@ -42,10 +62,3 @@ DTYPE_NAMES = {
def describe_dtype(dtype: torch.dtype) -> str: def describe_dtype(dtype: torch.dtype) -> str:
return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
# Boolean parametrization helpers shared by the test suite.
TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
import copy import copy
from io import BytesIO
import os import os
import pickle import pickle
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
...@@ -8,7 +7,7 @@ import pytest ...@@ -8,7 +7,7 @@ import pytest
import torch import torch
import bitsandbytes as bnb import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
storage = { storage = {
"uint8": torch.uint8, "uint8": torch.uint8,
...@@ -17,17 +16,6 @@ storage = { ...@@ -17,17 +16,6 @@ storage = {
"float32": torch.float32, "float32": torch.float32,
} }
def torch_save_to_buffer(obj):
    """Write *obj* to a fresh ``BytesIO`` using ``torch.save`` and rewind it."""
    stream = BytesIO()
    torch.save(obj, stream)
    stream.seek(0)
    return stream
def torch_load_from_buffer(buffer):
    """Load an object from *buffer* with ``torch.load``, leaving it rewound."""
    buffer.seek(0)
    restored = torch.load(buffer)
    buffer.seek(0)
    return restored
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE) @pytest.mark.parametrize("bias", TRUE_FALSE)
......
from contextlib import nullcontext from contextlib import nullcontext
from io import BytesIO
import os import os
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
...@@ -10,7 +9,12 @@ import bitsandbytes as bnb ...@@ -10,7 +9,12 @@ import bitsandbytes as bnb
from bitsandbytes import functional as F from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import TRUE_FALSE, id_formatter from tests.helpers import (
TRUE_FALSE,
id_formatter,
torch_load_from_buffer,
torch_save_to_buffer,
)
# contributed by Alex Borzunov, see: # contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
...@@ -66,17 +70,6 @@ def test_linear_no_igemmlt(): ...@@ -66,17 +70,6 @@ def test_linear_no_igemmlt():
assert linear_custom.state.CB is not None assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None assert linear_custom.state.CxB is None
def torch_save_to_buffer(obj):
    """Round-trip helper: ``torch.save`` *obj* into a rewound ``BytesIO``."""
    sink = BytesIO()
    torch.save(obj, sink)
    sink.seek(0)
    return sink
def torch_load_from_buffer(buffer):
    """Read an object back out of *buffer* with ``torch.load``.

    Rewinds before reading and again afterwards so callers can reuse
    the buffer.
    """
    buffer.seek(0)
    value = torch.load(buffer)
    buffer.seek(0)
    return value
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.