Unverified commit 10b9d4cd authored by Matthew Douglas, committed by GitHub

Add simple op implementations for CPU (#1602)

* Additional 4bit CPU ops

* Additional 4bit CPU ops

* Implement additional device-agnostic ops and test updates

* More test fixes

* int8 tests passing

* Fix feature flag for multi_backend
parent b7e60cac
......@@ -21,7 +21,7 @@ from .optim import adam
# This is a signal for integrations with transformers/diffusers.
# Eventually we may remove this but it is currently required for compatibility.
features = {"multi-backend"}
features = {"multi_backend"}
supported_torch_devices = {
"cpu",
"cuda", # NVIDIA/AMD GPU
......
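For integrations that key off this signal, a minimal check might look like the following. This is a sketch; the exact consumption pattern in transformers/diffusers may differ, and `features`/`supported_torch_devices` are assumed to be the module-level sets shown in the hunk above.

```python
import bitsandbytes as bnb

# Detect the multi-backend build via the features set
# (note the underscore form introduced by this change).
if "multi_backend" in getattr(bnb, "features", set()):
    print("bitsandbytes multi-backend build detected")

# Check whether the CPU backend is advertised.
if "cpu" in getattr(bnb, "supported_torch_devices", set()):
    print("CPU backend is available")
```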
from collections.abc import Sequence
import ctypes as ct
from typing import Optional
......@@ -119,6 +120,10 @@ def _(
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
......@@ -140,3 +145,73 @@ def _(
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax.float()
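For readers following the packing logic above, here is a standalone reference for blockwise 4-bit quantization: per-block absmax scaling, nearest-code lookup, and two 4-bit indices packed per byte. It is an illustrative sketch, not the library kernel; `code` stands in for the NF4 lookup table (`_NF4_QUANT_TABLE`), whose real values are nonuniform.

```python
import torch

def quantize_4bit_blockwise_ref(A: torch.Tensor, code: torch.Tensor, blocksize: int = 64):
    """Reference blockwise 4-bit quantization (sketch, not the CPU kernel)."""
    blocks = A.reshape(-1, blocksize)                   # one row per block
    absmax = blocks.abs().amax(dim=1, keepdim=True)     # per-block scale
    scaled = blocks / absmax                            # values in [-1, 1]; zero-guard omitted
    # Nearest entry in the 16-value code table (indices 0..15).
    idx = (scaled.unsqueeze(-1) - code).abs().argmin(dim=-1).to(torch.uint8)
    # Pack pairs of 4-bit indices into single bytes (high nibble first).
    idx = idx.reshape(-1, 2)
    packed = (idx[:, 0] << 4) | idx[:, 1]
    return packed.unsqueeze(1), absmax.flatten().float()

# Example with a stand-in code table; the real kernel uses the NF4 table.
code = torch.linspace(-1.0, 1.0, 16)
A = torch.randn(4, 64)
packed, absmax = quantize_4bit_blockwise_ref(A, code, blocksize=64)
```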
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(
A.dtype == torch.uint8,
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
)
A = A.view(-1, 1)
# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
upper = (A >> 4).to(torch.int64)
lower = (A & 0x0F).to(torch.int64)
# Expand to blocks
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
# Dequantize
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
# Reshape to original shape
blocks = blocks.reshape(-1, *shape[1:])
return blocks.to(dtype)
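A quick way to exercise the two CPU kernels end to end is a quantize/dequantize round trip through the functional API (a sketch, assuming the usual `quantize_4bit`/`dequantize_4bit` entry points dispatch to these registrations on CPU):

```python
import torch
import bitsandbytes.functional as F

# nf4 with uint8 storage is the CPU-supported path added in this PR.
A = torch.randn(1024, 1024, dtype=torch.float32, device="cpu")
packed, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
A_dq = F.dequantize_4bit(packed, state, blocksize=64, quant_type="nf4")
print((A - A_dq).abs().mean())  # small blockwise quantization error expected
```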
@register_kernel("bitsandbytes::gemv_4bit", "cpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# TODO: We need to determine whether `code` is NF4, FP4, or other.
# Right now we assume NF4, as this is the only one supported on CPU.
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
B,
absmax,
blocksize,
"nf4",
shape=shapeB,
dtype=A.dtype,
)
# User called gemv with B.t(), so we need to transpose it back.
# if B.shape[0] == 1:
# B_dq = B_dq.t()
return torch.nn.functional.linear(
A,
B_dq,
bias=None,
)
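The CPU `gemv_4bit` path above is simply dequantize-then-linear; conceptually it computes `A @ B_dq.T`. A minimal illustration of the shape convention, using a dense stand-in for the dequantized weight:

```python
import torch

# B_dq stands in for dequantize_4bit(B, absmax, blocksize, "nf4", shapeB, A.dtype).
A = torch.randn(1, 256)        # (batch, in_features)
B_dq = torch.randn(1024, 256)  # (out_features, in_features), i.e. shapeB
out = torch.nn.functional.linear(A, B_dq)  # equivalent to A @ B_dq.T -> (1, 1024)
```

Since `linear` already multiplies by the transpose of its weight argument, no explicit transpose of `B_dq` is needed here, which appears to be why the transpose above is left commented out.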
......@@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
_int8_linear_matmul_impl(A, B, out)
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
def _(
A: torch.Tensor,
CA: torch.Tensor,
CB: torch.Tensor,
SCA: torch.Tensor,
SCB: torch.Tensor,
outlier_cols: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
subB = None
if outlier_cols is not None and outlier_cols.numel():
# Extract the inputs with outliers in original precision
subA = A[:, outlier_cols].contiguous()
# Dequantize the corresponding weight columns
subB = (
torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
.to(A.dtype)
.t()
)
# TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
else:
# Needed for torch.compile when there are no outliers.
subA = torch.empty(0, device=A.device, dtype=A.dtype)
# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
if subB is not None:
# Add the outlier columns back to the output
output = output.addmm(subA, subB)
return output, subA
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
A, B = B, A
......
from math import prod
from typing import Optional
import torch
......@@ -5,6 +6,45 @@ import torch
from ..._ops import register_kernel
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
def _(
A: torch.Tensor,
CA: torch.Tensor,
CB: torch.Tensor,
SCA: torch.Tensor,
SCB: torch.Tensor,
outlier_cols: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
subB = None
if outlier_cols is not None and outlier_cols.numel():
# Extract the inputs with outliers in original precision
subA = A[:, outlier_cols].contiguous()
# Dequantize the corresponding weight columns
subB = (
torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
.to(A.dtype)
.t()
)
# TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
else:
# Needed for torch.compile when there are no outliers.
subA = torch.empty(0, device=A.device, dtype=A.dtype)
# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
if subB is not None:
# Add the outlier columns back to the output
output = output.addmm(subA, subB)
return output, subA
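The kernel above implements the LLM.int8()-style outlier decomposition: the int8 matmul covers the weight with outlier columns zeroed out, the outlier columns are handled separately in the original precision, and the two results are summed. A conceptual sketch of the math in plain floating point (not the quantized code path; the names are illustrative):

```python
import torch

def decomposed_matmul(A, W_t, outlier_cols):
    """A @ W_t split into an 'inlier' part (would be int8) and an outlier part (fp)."""
    inlier = A.clone()
    inlier[:, outlier_cols] = 0                      # these columns go through the fp path instead
    main = inlier @ W_t                              # stands in for int8_scaled_mm + dequant
    extra = A[:, outlier_cols] @ W_t[outlier_cols]   # original-precision outlier columns
    return main + extra

A = torch.randn(4, 8)
A[:, 3] = 100.0                                      # column 3 holds large "outlier" activations
W_t = torch.randn(8, 16)
out = decomposed_matmul(A, W_t, torch.tensor([3]))
print((out - A @ W_t).abs().max())                   # decomposition matches the full matmul
```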
@register_kernel("bitsandbytes::int8_scaled_mm", "default")
def _(
A: torch.Tensor,
......@@ -41,3 +81,41 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor
if out is not None:
result = out.copy_(result)
return result
@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
def _(A: torch.Tensor, threshold=0.0):
rows = prod(A.shape[:-1])
outlier_cols = None
outlier_restore = None
if threshold > 0.0:
outliers = A.abs() >= threshold
if outliers.any():
# Determine which columns contain outliers, and zero out the
# outliers ahead of quantization. We need to keep a backup of these
# outliers to restore them after quantization.
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
outlier_restore = A[outliers].clone()
A[outliers] = 0
else:
# Needed for torch.compile support.
outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
# Get absmax for each row.
row_stats = torch.max(A.abs(), dim=1).values.float()
# Quantize row-wise to int8.
out_row = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)
# Zero out values from outlier columns across all rows.
if rows > 1 and outlier_cols is not None:
out_row[:, outlier_cols] = 0
# Restore outliers.
if outlier_restore is not None:
A[outliers] = outlier_restore
return out_row, row_stats, outlier_cols
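A small usage sketch of the device-agnostic row-wise quantization above, mirroring the outlier-injection pattern in the test updates further down (assuming the public wrapper `bitsandbytes.functional.int8_vectorwise_quant` dispatches to this kernel):

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(10, 20, dtype=torch.float16)
A[1, 0] = 1000.0                                       # inject an outlier
CA, row_stats, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)
# CA is int8 with outlier columns zeroed, row_stats holds each row's absmax,
# and outlier_cols should contain column 0.
```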
......@@ -779,7 +779,7 @@ def quantize_blockwise(
state2=state2,
)
else:
quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
# TODO(matthewdouglas): Deprecate out kwarg
out = out.copy_(_out) if out is not None else _out
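The `code.to(A.device)` change keeps the quantization code table on the same device as the input, so the resulting `QuantState` can be consumed for dequantization without an implicit device transfer. A brief sketch, assuming the standard blockwise entry points:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(64, 64, device="cpu")
q, state = F.quantize_blockwise(A, blocksize=64)
assert state.code.device == A.device   # code table now follows the input's device
A_dq = F.dequantize_blockwise(q, state)
```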
......
......@@ -592,19 +592,28 @@ class Int8Params(torch.nn.Parameter):
obj.has_fp16_weights = has_fp16_weights
return obj
def cuda(self, device):
def _quantize(self, device):
if self.has_fp16_weights:
return super().cuda(device)
else:
# We quantize the weight and store in 8bit row-major
B = self.data.contiguous().half().cuda(device)
CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
self.data = CB
self.CB = CB
self.SCB = SCB
return super().to(device)
# We quantize the weight and store in 8bit row-major
B = self.data.contiguous().to(device=device, dtype=torch.float16)
CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
self.data = CB
self.CB = CB
self.SCB = SCB
return self
def cpu(self):
return self.to(device="cpu")
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
def __deepcopy__(self, memo):
# adjust this if new arguments are added to the constructor
new_instance = type(self).__new__(
......@@ -634,8 +643,8 @@ class Int8Params(torch.nn.Parameter):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
return self.cuda(device)
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
return self._quantize(device)
else:
new_param = Int8Params(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
......
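With `_quantize` factored out and the `to()` condition relaxed from `device.type == "cuda"` to any non-meta device, moving an 8-bit layer to CPU (or XPU) now triggers quantization just as a CUDA move did. A hedged usage sketch, consistent with the updated tests below:

```python
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False)
layer = layer.to("cpu")                  # previously only .cuda()/.to("cuda") quantized
assert layer.weight.dtype == torch.int8  # weight data is replaced by row-major int8
assert layer.weight.SCB is not None      # per-row scales are kept on the parameter
```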
......@@ -32,9 +32,15 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
def test_matmullt(
device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
):
if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
# TODO: Deprecate/remove?
pytest.skip("switchback_bnb only works on CUDA.")
if device != "cuda":
if funcs[1] == bnb.research.switchback_bnb:
# TODO: Deprecate/remove?
pytest.skip("switchback_bnb only works on CUDA.")
if req_grad[1]:
# This will be deprecated for CUDA in the future. We don't expect
# this to work on any other device.
pytest.skip("Deprecated feature with CUDA support only.")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
......@@ -171,7 +177,7 @@ def test_matmul_4bit(
quant_type,
):
if device == "cpu" and quant_type == "fp4":
pytest.skip("Only nf4 is supported on CPU")
pytest.xfail("Only nf4 is supported on CPU")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
......
......@@ -186,7 +186,7 @@ class Test8BitBlockwiseQuantizeFunctional:
code = F.create_dynamic_map(True, bits - 0, bits).to(device)
elif method == "quantile":
if device != "cuda":
pytest.xfail("Quantile map only works on CUDA")
pytest.skip("Quantile map only works on CUDA")
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
......@@ -593,7 +593,7 @@ class TestLLMInt8Functional:
A = A.view(-1, A.shape[-1])
CA, _, statsA, _, _ = F.int8_double_quant(A)
CA, statsA, _ = F.int8_vectorwise_quant(A)
CB, statsB, _ = F.int8_vectorwise_quant(B)
output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
......@@ -1102,6 +1102,9 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
......@@ -1134,6 +1137,9 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
def test_4bit_compressed_stats(self, device, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
errs1 = []
errs2 = []
for i in range(10):
......@@ -1206,6 +1212,12 @@ class TestQuantize4BitFunctional:
)
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
if device == "cpu":
if storage_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
if quant_storage != torch.uint8:
pytest.xfail("Only uint8 storage is supported on CPU")
errs1 = []
errs2 = []
errs3 = []
......@@ -1216,7 +1228,11 @@ class TestQuantize4BitFunctional:
max_errs2 = []
max_errs3 = []
for i in range(100):
# A large number of iterations is excessive and slow on CPU.
# Keep the full count for CUDA for now.
iters = 100 if device == "cuda" else 10
for i in range(iters):
if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
......@@ -1337,6 +1353,9 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
if device == "cpu" and storage_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims)
......
......@@ -25,7 +25,10 @@ storage = {
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
if device == "cpu":
pytest.xfail("Dequantization is not yet implemented for CPU")
if quant_type == "fp4":
pytest.xfail("FP4 is not supported for CPU")
if quant_storage != "uint8":
pytest.xfail("Only uint8 storage is supported for CPU")
original_dtype = torch.float16
compute_dtype = None
......@@ -144,8 +147,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
linear_q3 = torch_load_from_buffer(bytes_4bit)
# Test moving to CPU and back to GPU
linear_q2.to("cpu")
linear_q2.to(device)
if device != "cpu":
linear_q2.to("cpu")
linear_q2.to(device)
d = linear_qs(x)
assert c.dtype == d.dtype
assert c.device == d.device
......
......@@ -22,9 +22,6 @@ from tests.helpers import (
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.parametrize("device", get_available_devices())
def test_linear_no_igemmlt(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
linear_custom = Linear8bitLt(
......@@ -81,8 +78,8 @@ def test_linear_serialization(
save_before_forward,
load_before_cuda,
):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
if device != "cuda" and has_fp16_weights:
pytest.skip("has_fp16_weights is only supported on CUDA and is deprecated")
linear = torch.nn.Linear(32, 96)
# TODO: Fallback for bad shapes
......@@ -111,7 +108,7 @@ def test_linear_serialization(
if save_before_forward:
bytes_8bit = torch_save_to_buffer(linear_custom)
x_first = x.clone().cuda().requires_grad_(True)
x_first = x.clone().to(device).requires_grad_(True)
fx_first = linear_custom(x_first).float()
grad_proj = torch.randn_like(fx_first)
(fx_first * grad_proj).mean().backward()
......@@ -157,11 +154,11 @@ def test_linear_serialization(
if not load_before_cuda:
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
x_second = x.clone().cuda().requires_grad_(True)
x_second = x.clone().to(device).requires_grad_(True)
fx_second = new_linear_custom(x_second).float()
(fx_second * grad_proj).mean().backward()
x_third = x.clone().cuda().requires_grad_(True)
x_third = x.clone().to(device).requires_grad_(True)
fx_third = new_linear_custom2(x_third).float()
(fx_third * grad_proj).mean().backward()
......
......@@ -55,9 +55,6 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(device, threshold):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
assert l1.weight.device.type == device
assert l1.weight.dtype == torch.int8
......@@ -120,9 +117,6 @@ def test_linear8bitlt_accumulated_gradient(device):
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 2.0])
def test_linear8bitlt_no_fp16_weights(device, threshold):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
l1 = (
bnb.nn.Linear8bitLt(
32,
......@@ -211,7 +205,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
has_fp16_weights=False,
)
w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device) # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization
mlp = mlp.to(device).half() # and this line triggers quantization
for i in range(100):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
......@@ -253,9 +247,6 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
ids=["Int8Lt", "NF4"],
)
def test_linear_kbit_fp32_bias(device, module):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
# casts model to fp16 -> int8 automatically
l1 = module(32, 64).to(device)
assert l1.weight.dtype in [torch.int8, torch.uint8]
......@@ -295,7 +286,7 @@ module_dict = {
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(device, module):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
pytest.xfail("Test is not yet supported on CPU")
b = 16
dim1 = 36
......@@ -401,7 +392,10 @@ def test_fp8linear():
)
def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
if embedding_class is bnb.nn.EmbeddingFP4:
pytest.xfail("FP4 is not supported for CPU")
if quant_storage is not None and quant_storage != torch.uint8:
pytest.xfail("CPU only supports uint8 storage for 4bit")
num_embeddings = 128
......@@ -449,7 +443,10 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim,
)
def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
if embedding_class is bnb.nn.EmbeddingFP4:
pytest.xfail("FP4 is not supported for CPU")
if quant_storage is not None and quant_storage != torch.uint8:
pytest.xfail("CPU only supports uint8 storage for 4bit")
is_8bit = embedding_class is bnb.nn.Embedding8bit
......@@ -486,7 +483,7 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
pytest.xfail("gemv_4bit op is not yet implemented on CPU")
dim1 = 64
......@@ -525,9 +522,6 @@ def test_4bit_linear_warnings(device):
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_embedding_warnings(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
num_embeddings = 128
default_block_size = 64
......
......@@ -37,9 +37,6 @@ class TestLLMInt8Ops:
@pytest.mark.parametrize("threshold", [0.0, 6.0])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_vectorwise_quant(self, threshold, device):
if device == "cpu":
pytest.skip("CPU implementation is not available")
A = torch.randn(10, 20, dtype=torch.float16, device=device)
A[1][0] = 1000.0
......@@ -147,7 +144,7 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.skip("CPU implementation is only available for nf4")
pytest.xfail("CPU implementation is only available for nf4")
if storage_dtype != torch.uint8:
pytest.xfail("Known issue with storage_dtype != uint8")
......@@ -171,7 +168,11 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
if quant_type != "nf4":
pytest.xfail("CPU implementation is only available for nf4")
if storage_dtype != torch.uint8:
pytest.xfail("CPU implementation only supports uint8 storage")
shape = (128, 128)
......@@ -204,7 +205,7 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
pytest.xfail("CPU implementation is not available")
out_features = 1024
in_features = 256
......