Unverified Commit 1088ec52, authored by Matthew Douglas, committed by GitHub

Updates for device agnosticism (#1601)

* Include device support tags for transformers multi-backend compatibility; add xpu() and cpu() to Params4bit

* Make test suite more device-agnostic

* Additional device-agnostic tests

* Additional device agnosticism for tests

* Add BNB_TEST_DEVICE env var to manually select the device for unit tests (usage sketch below)

* Small bugfix for int8 test

* Exclude backward() from code coverage reports

* Params4bit: don't try to quantize when moving to meta device
parent 97073cdb
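The BNB_TEST_DEVICE override mentioned above can be exercised roughly as follows. This is a minimal sketch, not part of the patch itself; it only assumes the `get_available_devices()` helper added to `tests/helpers.py` in the diff below and a device string that the local build actually supports.

```python
import os

# Select the device for the unit tests. Set this before get_available_devices()
# is first called, since its result is cached with functools.cache.
os.environ["BNB_TEST_DEVICE"] = "xpu"  # any torch device string, e.g. "cpu", "cuda", "mps"

from tests.helpers import get_available_devices

# With the override set, the helper returns exactly that device and the
# device-parametrized tests run only against it.
assert get_available_devices() == ["xpu"]
```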
@@ -20,15 +20,15 @@ from .nn import modules
 from .optim import adam

 # This is a signal for integrations with transformers/diffusers.
-# Eventually, we will remove this and check based on release version.
+# Eventually we may remove this but it is currently required for compatibility.
 features = {"multi-backend"}

 supported_torch_devices = {
-    "cuda",
     "cpu",
-    # "mps",
-    # "xpu",
-    # "hpu",
-    # "npu",
+    "cuda",  # NVIDIA/AMD GPU
+    "xpu",  # Intel GPU
+    "hpu",  # Gaudi
+    "npu",  # Ascend NPU
+    "mps",  # Apple Silicon
 }

 if torch.cuda.is_available():
...
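For context, `features` and `supported_torch_devices` above are the signals an integration can query. A minimal, hypothetical sketch of such a check (only the attribute names come from the patch; the surrounding function is illustrative, not transformers' actual code):

```python
import torch
import bitsandbytes as bnb

def bnb_supports_device(device: torch.device) -> bool:
    """Hypothetical helper showing how a downstream library might consume the signal."""
    if "multi-backend" not in getattr(bnb, "features", set()):
        return False
    return device.type in getattr(bnb, "supported_torch_devices", {"cuda"})

print(bnb_supports_device(torch.device("cpu")))  # True with this patch applied
print(bnb_supports_device(torch.device("xpu")))  # True: "xpu" is now listed
```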
@@ -284,7 +284,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 dtype=torch.float16,
             )
-            if state.threshold > 0.0 and subA is not None:
+            if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
...
@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     for i in range(gap):
         values.append(0)
     values.sort()
-    code = torch.Tensor(values)
+    code = torch.tensor(values)
     code /= code.max()
     return code
...
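The only change here swaps the legacy `torch.Tensor(values)` constructor for the `torch.tensor()` factory. For the list of Python floats built in `create_fp8_map` the result is the same float32 tensor, but the factory is the recommended API because it infers the dtype from its input. A small illustration (not from the patch):

```python
import torch

ints = [1, 2, 3]
print(torch.Tensor(ints).dtype)  # torch.float32 - the legacy constructor ignores the input dtype
print(torch.tensor(ints).dtype)  # torch.int64  - the factory infers dtype from the data

floats = [0.0, 0.25, 0.5, 1.0]
# For float inputs both produce float32, so the change in create_fp8_map is a
# cleanup rather than a behavioral difference.
assert torch.Tensor(floats).dtype == torch.tensor(floats).dtype == torch.float32
```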
@@ -306,9 +306,15 @@ class Params4bit(torch.nn.Parameter):
         self.bnb_quantized = True
         return self

+    def cpu(self):
+        return self.to(device="cpu")
+
     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

+    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
+
     @overload
     def to(
         self: T,
@@ -326,7 +332,7 @@ class Params4bit(torch.nn.Parameter):
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if device is not None and device.type == "cuda" and not self.bnb_quantized:
+        if device is not None and device.type != "meta" and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
...
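Putting the Params4bit changes together: moving an unquantized parameter to any real (non-meta) device now triggers quantization, and the new `cpu()`/`xpu()` helpers mirror the existing `cuda()`. A minimal usage sketch, assuming nf4 quantization is available on CPU (the constructor arguments mirror those used in the tests later in this diff):

```python
import torch
import bitsandbytes as bnb

w = torch.linspace(1, 64, 64)

# Moving to "meta" no longer attempts quantization.
p_meta = bnb.nn.Params4bit(data=w.clone(), quant_type="nf4", blocksize=64, requires_grad=False).to("meta")

# Moving to a real device (here CPU; .xpu() follows the same pattern on an Intel GPU)
# quantizes the weight on that device.
p = bnb.nn.Params4bit(data=w.clone(), quant_type="nf4", blocksize=64, requires_grad=False).cpu()
assert p.bnb_quantized and p.quant_state is not None
```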
@@ -79,6 +79,12 @@ include = ["bitsandbytes*"]
 [tool.setuptools.dynamic]
 version = {attr = "bitsandbytes.__version__"}

+[tool.coverage.report]
+exclude_also = [
+    # exclude backward() functions from coverage, as they are invoked from C++
+    'def backward\(ctx'
+]
+
 [tool.pytest.ini_options]
 addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
 # ; --cov=bitsandbytes
...
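The `exclude_also` entry is a regular expression matched against source lines. For illustration only (hypothetical class, not from the repository), it excludes autograd `backward` definitions like the following, whose bodies are entered from C++ and therefore never show up in Python coverage traces:

```python
import torch

class _ExampleFn(torch.autograd.Function):  # hypothetical, for illustration only
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):  # matched by r'def backward\(ctx' and excluded from coverage
        return grad_output * 2
```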
+import functools
 from io import BytesIO
 from itertools import product
+import os
 import random
 from typing import Any
@@ -13,6 +15,38 @@ BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bool, bool, bool)
 BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)

+
+@functools.cache
+def get_available_devices():
+    if "BNB_TEST_DEVICE" in os.environ:
+        # If the environment variable is set, use it directly.
+        return [os.environ["BNB_TEST_DEVICE"]]
+
+    devices = ["cpu"]
+
+    if hasattr(torch, "accelerator"):
+        # PyTorch 2.6+ - determine accelerator using agnostic API.
+        if torch.accelerator.is_available():
+            devices += [str(torch.accelerator.current_accelerator())]
+    else:
+        if torch.cuda.is_available():
+            devices += ["cuda"]
+
+        if torch.backends.mps.is_available():
+            devices += ["mps"]
+
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            devices += ["xpu"]
+
+        custom_backend_name = torch._C._get_privateuse1_backend_name()
+        custom_backend_module = getattr(torch, custom_backend_name, None)
+        custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
+        if custom_backend_is_available_fn and custom_backend_module.is_available():
+            devices += [custom_backend_name]
+
+    return devices
+
+
 def torch_save_to_buffer(obj):
     buffer = BytesIO()
     torch.save(obj, buffer)
...
@@ -6,12 +6,14 @@ from tests.helpers import (
     BOOLEAN_TRIPLES,
     TRUE_FALSE,
     describe_dtype,
+    get_available_devices,
     id_formatter,
 )

 TRANSPOSE_VALS = [(False, True), (False, False)]


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@@ -27,10 +29,16 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
+def test_matmullt(
+    device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
+):
+    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
+        # TODO: Deprecate/remove?
+        pytest.skip("switchback_bnb only works on CUDA.")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)

     if has_bias == False:
         req_grad = list(req_grad)
         req_grad[2] = False
@@ -38,21 +46,21 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
-                device="cuda",
+                device=device,
                 requires_grad=req_grad[1],
                 dtype=dtype,
             )
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
@@ -91,6 +99,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
             if has_fp16_weights:
                 if any(req_grad):
                     out_bnb.data.copy_(out_torch)
+                    if device == "cuda":
                         torch.cuda.synchronize()
                     loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                     loss_bnb.backward()
@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                     torch.testing.assert_close(gradBias1, gradBias2)


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
 def test_matmul_4bit(
+    device,
     dim1,
     dim2,
     dim3,
@@ -159,6 +170,9 @@ def test_matmul_4bit(
     compress_statistics,
     quant_type,
 ):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("Only nf4 is supported on CPU")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
@@ -168,13 +182,13 @@ def test_matmul_4bit(
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
@@ -204,6 +218,7 @@ def test_matmul_4bit(
             # assert err < 0.20
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
+                if device == "cuda":
                     torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
...
This diff is collapsed.
@@ -7,7 +7,7 @@ import pytest
 import torch

 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer

 storage = {
     "uint8": torch.uint8,
@@ -17,15 +17,18 @@ storage = {
 }


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
-@pytest.mark.parametrize("bias", TRUE_FALSE)
-@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+    if device == "cpu":
+        pytest.xfail("Dequantization is not yet implemented for CPU")
+
     original_dtype = torch.float16
     compute_dtype = None
-    device = "cuda"
     layer_shape = (300, 400)

     linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
@@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)

     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
@@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert size_ratio < target_compression, ratio_error_msg


-def test_copy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_copy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)

     shallow_copy_param = copy.copy(param)
     assert param.quant_state is shallow_copy_param.quant_state
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


-def test_deepcopy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)
+
     dict_keys_before = set(param.__dict__.keys())
     copy_param = copy.deepcopy(param)
     dict_keys_after = set(param.__dict__.keys())
@@ -199,12 +234,27 @@ def test_deepcopy_param():
     assert dict_keys_before == dict_keys_copy


-def test_params4bit_real_serialization():
-    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
-    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(
+        data=original_tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+    )

     dict_keys_before = set(original_param.__dict__.keys())
-    original_param.cuda(0)  # move to CUDA to trigger quantization
+    original_param.to(device)  # change device to trigger quantization

     serialized_param = pickle.dumps(original_param)
     deserialized_param = pickle.loads(serialized_param)
...
@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
+    get_available_devices,
     id_formatter,
     torch_load_from_buffer,
     torch_save_to_buffer,
@@ -19,7 +20,11 @@ from tests.helpers import (

 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
-def test_linear_no_igemmlt():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_linear_no_igemmlt(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     linear = torch.nn.Linear(1024, 3072)
     x = torch.randn(3, 1024, dtype=torch.half)
     linear_custom = Linear8bitLt(
@@ -29,6 +34,8 @@ def test_linear_no_igemmlt():
         has_fp16_weights=False,
         threshold=6.0,
     )
+
+    # TODO: Remove, this is no longer implemented
     linear_custom.state.force_no_igemmlt = True

     linear_custom.weight = bnb.nn.Int8Params(
@@ -37,11 +44,11 @@ def test_linear_no_igemmlt():
         has_fp16_weights=False,
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear_custom = linear_custom.cuda()
-    linear = linear.half().cuda()
-    x_ref = x.clone().cuda().requires_grad_(True)
-    x_ours = x.clone().cuda().requires_grad_(True)
+    linear_custom = linear_custom.to(device)
+    linear = linear.half().to(device)
+    x_ref = x.clone().to(device).requires_grad_(True)
+    x_ours = x.clone().to(device).requires_grad_(True)
     fx_ref = linear(x_ref).float()
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()
@@ -58,18 +65,25 @@ def test_linear_no_igemmlt():
     torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
+@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
 def test_linear_serialization(
+    device,
     has_fp16_weights,
+    threshold,
     serialize_before_forward,
     deserialize_before_cuda,
     save_before_forward,
     load_before_cuda,
 ):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     linear = torch.nn.Linear(32, 96)
     # TODO: Fallback for bad shapes
     x = torch.randn(4, 32, dtype=torch.half)
@@ -80,7 +94,7 @@ def test_linear_serialization(
         linear.out_features,
         linear.bias is not None,
         has_fp16_weights=has_fp16_weights,
-        threshold=6.0,
+        threshold=threshold,
     )

     linear_custom.weight = bnb.nn.Int8Params(
@@ -89,7 +103,7 @@ def test_linear_serialization(
         has_fp16_weights=has_fp16_weights,
     )
     linear_custom.bias = linear.bias
-    linear_custom = linear_custom.cuda()
+    linear_custom = linear_custom.to(device)

     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
@@ -125,7 +139,7 @@ def test_linear_serialization(
         linear.out_features,
         linear.bias is not None,
         has_fp16_weights=has_fp16_weights,
-        threshold=6.0,
+        threshold=threshold,
     )

     if deserialize_before_cuda:
@@ -135,7 +149,7 @@ def test_linear_serialization(
     if load_before_cuda:
         new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

-    new_linear_custom = new_linear_custom.cuda()
+    new_linear_custom = new_linear_custom.to(device)

     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)
...
This diff is collapsed.
@@ -4,11 +4,11 @@ import pytest
 import torch

 import bitsandbytes
-from tests.helpers import TRUE_FALSE, id_formatter
+from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter


 class TestLLMInt8Ops:
-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     def test_int8_linear_matmul(self, device):
         A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
         B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
@@ -20,7 +20,7 @@ class TestLLMInt8Ops:
         torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))

-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     def test_int8_linear_matmul_out(self, device):
         A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
         B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
@@ -35,7 +35,7 @@ class TestLLMInt8Ops:
         torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))

     @pytest.mark.parametrize("threshold", [0.0, 6.0])
-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     def test_int8_vectorwise_quant(self, threshold, device):
         if device == "cpu":
             pytest.skip("CPU implementation is not available")
@@ -64,7 +64,7 @@ class TestLLMInt8Ops:
         torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))

-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     def test_int8_mm_dequant(self, device):
         A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
         row_stats = torch.randn(256, dtype=torch.float32, device=device)
@@ -77,7 +77,7 @@ class TestLLMInt8Ops:
         torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))

-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("has_bias", TRUE_FALSE)
     def test_int8_scaled_mm(self, device, dtype, has_bias):
@@ -96,7 +96,7 @@ class TestLLMInt8Ops:

 class TestInt8BlockwiseQuantOps:
-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_blockwise(self, device, dtype, blocksize):
@@ -116,7 +116,7 @@ class TestInt8BlockwiseQuantOps:
         torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))

-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_dequantize_blockwise(self, device, dtype, blocksize):
@@ -140,7 +140,7 @@ class TestInt8BlockwiseQuantOps:

 class Test4bitBlockwiseQuantOps:
-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -164,7 +164,7 @@ class Test4bitBlockwiseQuantOps:
         torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))

-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -197,7 +197,7 @@ class Test4bitBlockwiseQuantOps:
             torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
         )

-    @pytest.mark.parametrize("device", ["cpu", "cuda"])
+    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
...
@@ -47,7 +47,6 @@ str2optimizers["momentum_pytorch"] = (
 )

 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
-str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
 str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
 str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
 str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
@@ -88,19 +87,14 @@ str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
 )

 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
-str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
-str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
 str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
+str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
 str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))

 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["momentum8bit"] = (
-    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
-)
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -110,10 +104,6 @@ str2optimizers["rmsprop"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["rmsprop8bit"] = (
-    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
-)
 str2optimizers["rmsprop8bit_blockwise"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -128,8 +118,7 @@ str2statenames["paged_lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]
-str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
-str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
 str2statenames["adam8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
@@ -142,10 +131,8 @@ str2statenames["paged_adamw8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
-str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
-str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
 str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
-str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
 str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
 str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
@@ -180,7 +167,7 @@ optimizer_names_32bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
-def test_optimizer32bit(dim1, dim2, gtype, optim_name):
+def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
     if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
         pytest.skip()
     if dim1 == 1 and dim2 == 1:
@@ -256,7 +243,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
-def test_global_config(dim1, dim2, gtype):
+def test_global_config(requires_cuda, dim1, dim2, gtype):
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
@@ -298,10 +285,11 @@ def test_global_config(dim1, dim2, gtype):

 optimizer_names_8bit = [
-    "adam8bit",
-    "lion8bit",
-    "momentum8bit",
-    "rmsprop8bit",
+    # Non-blockwise optimizers are deprecated.
+    # "adam8bit",
+    # "lion8bit",
+    # "momentum8bit",
+    # "rmsprop8bit",
     "adam8bit_blockwise",
     "lion8bit_blockwise",
     "momentum8bit_blockwise",
@@ -315,7 +303,7 @@ optimizer_names_8bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-def test_optimizer8bit(dim1, dim2, gtype, optim_name):
+def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
     torch.set_printoptions(precision=6)

     if gtype == torch.bfloat16 and "blockwise" not in optim_name:
@@ -479,7 +467,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
 @pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
+@pytest.mark.deprecated
+def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
...
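With the non-blockwise 8-bit optimizers deprecated above, the blockwise variants remain the supported path. A minimal usage sketch (hypothetical model and hyperparameters; it mirrors the `block_wise=True` entries in the test tables):

```python
import torch
import bitsandbytes as bnb

# Sketch only: any trainable module works; a single Linear keeps it short.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, block_wise=True)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```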