from contextlib import nullcontext import copy import os import pickle import platform from tempfile import TemporaryDirectory import pytest import torch import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer, ) # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py @pytest.mark.parametrize("device", get_available_devices()) def test_linear_no_igemmlt(device): linear = torch.nn.Linear(1024, 3072) x = torch.randn(3, 1024, dtype=torch.half) linear_custom = Linear8bitLt( linear.in_features, linear.out_features, linear.bias is not None, has_fp16_weights=False, threshold=6.0, ) # TODO: Remove, this is no longer implemented linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False, ).to(linear.weight.dtype) linear_custom.bias = linear.bias linear_custom = linear_custom.to(device) linear = linear.half().to(device) x_ref = x.clone().to(device).requires_grad_(True) x_ours = x.clone().to(device).requires_grad_(True) fx_ref = linear(x_ref).float() grad_proj = torch.randn_like(fx_ref) (fx_ref * grad_proj).mean().backward() fx_ours = linear_custom(x_ours).float() (fx_ours * grad_proj).mean().backward() assert linear_custom.state.CB is not None assert not linear_custom.state.has_fp16_weights idx = torch.isclose(fx_ref, fx_ours, atol=0.02, rtol=1e-5) assert (idx == 0).sum().item() < fx_ref.numel() * 2.5e-4 torch.testing.assert_close(fx_ref, fx_ours, atol=0.03, rtol=1e-5) torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) def test_linear_serialization( device, has_fp16_weights, threshold, serialize_before_forward, deserialize_before_cuda, save_before_forward, load_before_cuda, ): if device != "cuda" and has_fp16_weights: pytest.skip("has_fp16_weights is only supported on CUDA and is deprecated") linear = torch.nn.Linear(32, 96) # TODO: Fallback for bad shapes x = torch.randn(4, 32, dtype=torch.half) # x = torch.randn(3, 32, dtype=torch.half) linear_custom = Linear8bitLt( linear.in_features, linear.out_features, linear.bias is not None, has_fp16_weights=has_fp16_weights, threshold=threshold, ) linear_custom.weight = bnb.nn.Int8Params( linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights, ) linear_custom.bias = linear.bias linear_custom = linear_custom.to(device) if serialize_before_forward: state_dict_8bit = linear_custom.state_dict() if save_before_forward: bytes_8bit = torch_save_to_buffer(linear_custom) x_first = x.clone().to(device).requires_grad_(True) fx_first = linear_custom(x_first).float() grad_proj = torch.randn_like(fx_first) (fx_first * grad_proj).mean().backward() if not serialize_before_forward: state_dict_8bit = linear_custom.state_dict() if not save_before_forward: bytes_8bit = torch_save_to_buffer(linear_custom) with TemporaryDirectory() as tmpdir: state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") state_path = os.path.join(tmpdir, "state.pth") torch.save(linear.state_dict(), state_path) torch.save(state_dict_8bit, state_path_8bit) if not has_fp16_weights: assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path) new_state_dict = torch.load(state_path_8bit, weights_only=False) new_linear_custom = Linear8bitLt( linear.in_features, linear.out_features, linear.bias is not None, has_fp16_weights=has_fp16_weights, threshold=threshold, ) if deserialize_before_cuda: with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): new_linear_custom.load_state_dict(new_state_dict, strict=True) if load_before_cuda: new_linear_custom2 = torch_load_from_buffer(bytes_8bit) new_linear_custom = new_linear_custom.to(device) if not deserialize_before_cuda: new_linear_custom.load_state_dict(new_state_dict, strict=True) if not load_before_cuda: new_linear_custom2 = torch_load_from_buffer(bytes_8bit) x_second = x.clone().to(device).requires_grad_(True) fx_second = new_linear_custom(x_second).float() (fx_second * grad_proj).mean().backward() x_third = x.clone().to(device).requires_grad_(True) fx_third = new_linear_custom2(x_third).float() (fx_third * grad_proj).mean().backward() # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised if has_fp16_weights or not deserialize_before_cuda: assert torch.allclose(fx_first, fx_second, atol=1e-5) assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) assert torch.allclose(fx_first, fx_third, atol=1e-5) assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5) @pytest.fixture def linear8bit(requires_cuda): linear = torch.nn.Linear(32, 96) linear_custom = Linear8bitLt( linear.in_features, linear.out_features, linear.bias is not None, has_fp16_weights=False, threshold=6.0, ) linear_custom.weight = bnb.nn.Int8Params( linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False, ) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() return linear_custom def test_linear8bit_copy_param(linear8bit): shallow_copy = copy.copy(linear8bit) assert linear8bit.weight is shallow_copy.weight assert linear8bit.bias is shallow_copy.bias assert linear8bit.weight.data.data_ptr() == shallow_copy.weight.data.data_ptr() def test_linear8bit_deepcopy_param(linear8bit): deep_copy = copy.deepcopy(linear8bit) assert linear8bit.weight is not deep_copy.weight assert linear8bit.bias is not deep_copy.bias assert linear8bit.weight.data.data_ptr() != deep_copy.weight.data.data_ptr() assert torch.allclose(linear8bit.weight.data, deep_copy.weight.data) assert linear8bit.state == deep_copy.state # check for a bug where SCB and CB were not copied assert deep_copy.weight.SCB is not None assert (linear8bit.weight.SCB == deep_copy.weight.SCB).all() assert deep_copy.weight.CB is not None assert (linear8bit.weight.CB == deep_copy.weight.CB).all() def test_linear8bit_serialization(linear8bit): serialized = pickle.dumps(linear8bit) deserialized = pickle.loads(serialized) assert linear8bit.weight.data.data_ptr() != deserialized.weight.data.data_ptr() assert torch.allclose(linear8bit.weight.data, deserialized.weight.data) assert linear8bit.bias.data.data_ptr() != deserialized.bias.data.data_ptr() assert torch.allclose(linear8bit.bias.data, deserialized.bias.data) assert linear8bit.state == deserialized.state # check for a bug where SCB and CB were not copied assert (linear8bit.weight.SCB == deserialized.weight.SCB).all() assert (linear8bit.weight.CB == deserialized.weight.CB).all() @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold")) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): if device == "cuda" and platform.system() == "Windows": pytest.skip("Triton is not officially supported on Windows") dim = 256 batch_size = 16 torch.compiler.reset() # Create a small network with Linear8bitLt layers net = torch.nn.Sequential( *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)] ).to(device) dynamic_output_shapes = fullgraph and threshold > 0 with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes): # Create input tensor x = torch.randn(batch_size, dim, dtype=torch.float16, device=device) # Get reference output before compilation with torch.no_grad(): ref_output = net(x) # Compile the model compile_backend = "hpu_backend" if device == "hpu" else "inductor" compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend) # Get output from compiled model with torch.no_grad(): compiled_output = compiled_net(x) # Check outputs match assert compiled_output.shape == ref_output.shape assert compiled_output.device == ref_output.device assert compiled_output.dtype == ref_output.dtype torch.testing.assert_close(compiled_output, ref_output) # Test with gradients. Currently only works with threshold=0. # Has a strange regression on Linux aarch64 CPU in torch==2.6.0. is_broken_platform = ( device == "cpu" and platform.system() == "Linux" and platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7) ) if threshold == 0 and not is_broken_platform: x.requires_grad_(True) y1 = net(x).sum() y1.backward() grad_ref = x.grad.clone() x.grad = None y2 = compiled_net(x).sum() y2.backward() grad_compiled = x.grad.clone() torch.testing.assert_close(grad_compiled, grad_ref)