Unverified Commit e82f72b3 authored by Matthew Douglas, committed by GitHub

PyTorch Custom Operator Integration (#1544)



* Sketch out first custom op registration

* Add note

* Initial int8 op registration

* Cleanup some deprecated functions.

* Int8 ops updates; tests

* Implement 4bit quant/dequant ops

* Fix nested quant

* cleanup

* Test improvements

* Clean up and improve tests

* Add higher level custom op for int8 matmul + dequant + bias

* Add gemv 4bit custom op

* Cleanup

* Implement out kwarg overloads for custom ops

* Update PyTorch minimum to 2.1

* Deprecation updates

* Deprecation updates

* Cleanup; rename int8_linear_dequant -> int8_scaled_mm

* Bump min pytorch to 2.2

* cleanup

* Test reorganization

* Remove deprecated supports_igemmlt

* More cleanup

* Cleanup obsolete C++/CUDA code

* Cleanup

* Create 'default' backend for fallback op implementations; initial CPU nf4 work

* Stub out for multi-platform

* Fix serialization tests for torch>=2.6.0

* Add example for torch.compile e2e inference

* Test update

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent f0735f95
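The change set registers the bitsandbytes kernels as PyTorch custom operators under the torch.ops.bitsandbytes namespace, so they dispatch per device and can be traced by torch.compile. As a point of reference, a minimal sketch of the general torch.library registration pattern follows; the "demo::scaled_mm" op, its schema, and its kernels are hypothetical illustrations, not the operators added by this PR.

import torch

# Illustrative only: a made-up "demo" namespace, not bitsandbytes' actual registration code.
lib = torch.library.Library("demo", "DEF")
lib.define("scaled_mm(Tensor A, Tensor B, Tensor scale) -> Tensor")

def scaled_mm_cpu(A, B, scale):
    # Eager reference kernel registered for the CPU dispatch key.
    return (A @ B.t()) * scale

def scaled_mm_meta(A, B, scale):
    # Shape/dtype-only implementation so the op can be traced with fake tensors.
    return A.new_empty((A.shape[0], B.shape[0]))

lib.impl("scaled_mm", scaled_mm_cpu, "CPU")
lib.impl("scaled_mm", scaled_mm_meta, "Meta")

out = torch.ops.demo.scaled_mm(torch.randn(4, 8), torch.randn(16, 8), torch.tensor(0.5))

The new example file below applies the result end to end: an 8-bit transformers model compiled with torch.compile.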
import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")

# Load the model with 8-bit (LLM.int8()) quantization via transformers.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# torch._dynamo.config.capture_dynamic_output_shape_ops = True

model_id = "google/gemma-2-2b-it"
# model_id = "Qwen/Qwen2.5-7B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

# Compile the full model; the quantized bitsandbytes kernels are registered as
# custom ops, so torch.compile can trace through them.
# model.forward = torch.compile(model.forward, fullgraph=True)
model = torch.compile(model)

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
@@ -42,7 +42,7 @@ classifiers = [
    "Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
-   "torch>=2.0,<3",
+   "torch>=2.2,<3",
    "numpy>=1.17"
]
...
@@ -22,7 +22,7 @@ def torch_save_to_buffer(obj):
def torch_load_from_buffer(buffer):
    buffer.seek(0)
-   obj = torch.load(buffer)
+   obj = torch.load(buffer, weights_only=False)
    buffer.seek(0)
    return obj
@@ -36,6 +36,8 @@ def format_with_label(label: str, value: Any) -> str:
        formatted = "T" if value else "F"
    elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
        formatted = "".join("T" if b else "F" for b in value)
+   elif isinstance(value, torch.dtype):
+       formatted = describe_dtype(value)
    else:
        formatted = str(value)
    return f"{label}={formatted}"
...
-from typing import Tuple
import pytest
import torch
import bitsandbytes as bnb
from tests.helpers import (
    BOOLEAN_TRIPLES,
-   BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
@@ -16,189 +13,6 @@ from tests.helpers import (
TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize(
"funcs",
[(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(25):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
elif not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
elif transpose[0] and not transpose[1]:
out_torch = funcs[0](A.t(), B)
out_bnb = funcs[1](A.t(), B)
elif transpose[0] and transpose[1]:
out_torch = funcs[0](A.t(), B.t())
out_bnb = funcs[1](A.t(), B.t())
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
B = torch.randn(
size=(dim1, dim3, dim4),
device="cuda",
requires_grad=req_grad[1],
)
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
else:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
......
import numpy as np
import pytest
from scipy.stats import norm
import torch
from bitsandbytes import functional as F
@pytest.mark.deprecated
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1 - val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 4):
total_values = 2**bits - 1
p = np.linspace(0, 1, 2 * total_values + 1)
idx = np.arange(1, 2 * total_values + 1, 2)
p = p[idx]
offset = 1 / (2 * total_values)
p = np.linspace(offset, 1 - offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
err = torch.abs(val1 - val2).mean()
assert err < 0.035
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0
@pytest.mark.deprecated
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
@pytest.mark.deprecated
def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
print(sum(diffs) / len(diffs))
print(sum(reldiffs) / len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
percentile = 5
for i in range(20):
step += 1
g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
gnorm_vec1[:] = gnorm2
else:
gnorm_vec1[step % 100] = gnorm2
vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile]
torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
torch.testing.assert_close(clip1, clip2)
torch.testing.assert_close(gnorm1, gnorm2)
This diff is collapsed.
@@ -8,8 +8,6 @@ import pytest
import torch
import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
    TRUE_FALSE,
@@ -18,28 +16,9 @@ from tests.helpers import (
    torch_save_to_buffer,
)
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
tile_indices = get_inverse_transform_indices(transform, tile_size)
cxb = transform(x)
torch.cuda.synchronize()
restored_x = undo_layout(cxb, tile_indices)
torch.cuda.synchronize()
assert restored_x.is_contiguous()
assert torch.all(torch.eq(restored_x, x))
def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
@@ -139,7 +118,7 @@ def test_linear_serialization(
    if not has_fp16_weights:
        assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
-   new_state_dict = torch.load(state_path_8bit)
+   new_state_dict = torch.load(state_path_8bit, weights_only=False)
    new_linear_custom = Linear8bitLt(
        linear.in_features,
...
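The weights_only=False additions above track the "Fix serialization tests for torch>=2.6.0" commit: PyTorch 2.6 switched the torch.load default to weights_only=True, which rejects checkpoints containing arbitrary pickled Python objects. A minimal sketch of the failure mode and the opt-out; the Meta class is a hypothetical stand-in for non-tensor state such as quantization metadata.

import io
import torch

class Meta:  # stand-in for a non-tensor object stored inside a checkpoint
    pass

buf = io.BytesIO()
torch.save({"meta": Meta(), "w": torch.zeros(2)}, buf)
buf.seek(0)

# On torch>=2.6 the default weights_only=True raises for the Meta instance;
# weights_only=False restores the old behavior for trusted files.
state = torch.load(buf, weights_only=False)

The new op-level tests below then exercise each registered operator directly and validate its registration with torch.library.opcheck.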
from math import prod
import pytest
import torch
import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter
class TestLLMInt8Ops:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_linear_matmul(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
assert out.shape == (10, 30)
assert out.dtype == torch.int32
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_linear_matmul_out(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
out = torch.empty((10, 30), dtype=torch.int32, device=device)
torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
assert out.shape == (10, 30)
assert out.dtype == torch.int32
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))
@pytest.mark.parametrize("threshold", [0.0, 6.0])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_vectorwise_quant(self, threshold, device):
if device == "cpu":
pytest.skip("CPU implementation is not available")
A = torch.randn(10, 20, dtype=torch.float16, device=device)
A[1][0] = 1000.0
out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)
assert out_row.shape == (10, 20)
assert out_row.dtype == torch.int8
assert out_row.device == A.device
assert row_stats.shape == (10,)
assert row_stats.dtype == torch.float32
assert row_stats.device == A.device
if threshold > 0.0:
assert outlier_cols is not None
assert outlier_cols.dim() == 1
assert outlier_cols.shape[0] <= A.shape[1]
assert outlier_cols.device == A.device
else:
assert outlier_cols is None
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_mm_dequant(self, device):
A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
row_stats = torch.randn(256, dtype=torch.float32, device=device)
col_stats = torch.randn(256, dtype=torch.float32, device=device)
out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)
assert out.shape == A.shape
assert out.dtype == torch.float16
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE)
def test_int8_scaled_mm(self, device, dtype, has_bias):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
row_stats = torch.randn(10, dtype=torch.float32, device=device)
col_stats = torch.randn(30, dtype=torch.float32, device=device)
bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)
assert out.shape == (10, 30)
assert out.dtype == dtype
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))
class TestInt8BlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_blockwise(self, device, dtype, blocksize):
if device == "cpu" and dtype != torch.float32:
pytest.skip("CPU implementation is only available for float32")
code = bitsandbytes.functional.create_dynamic_map().to(device)
A = torch.randn(1024, 1024, dtype=dtype, device=device)
out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
assert out.shape == A.shape
assert out.dtype == torch.uint8
assert out.device == A.device
assert absmax.device == A.device
assert absmax.dtype == torch.float32
torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_blockwise(self, device, dtype, blocksize):
if device == "cpu" and dtype != torch.float32:
pytest.skip("CPU implementation is only available for float32")
A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.randn((blocks,), device=device, dtype=torch.float32)
out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype)
assert out.shape == A.shape
assert out.dtype == dtype
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.skip("CPU implementation is only available for nf4")
A = torch.randn(1024, 1024, dtype=dtype, device=device)
out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)
assert out.device == A.device
assert out.dtype == storage_dtype
assert absmax.device == A.device
assert absmax.dtype == torch.float32
torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
shape = (128, 128)
n = prod(shape)
blocks = -(n // -blocksize)
quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)
A = (
torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device)
.view(storage_dtype)
.reshape(quantized_shape)
.contiguous()
)
absmax = torch.randn((blocks,), dtype=torch.float32, device=device)
out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
assert out.device == A.device
assert out.shape == shape
torch.library.opcheck(
torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
out_features = 1024
in_features = 256
A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)
out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
assert out.device == A.device
assert out.dtype == dtype
assert out.shape == (1, 1, out_features)
assert out.isreal().all()
torch.library.opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))
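Finally, the registered operators can also be called directly and compiled end to end. A small sketch, assuming a CUDA-enabled bitsandbytes build; the shapes and the int8_mm wrapper are arbitrary.

import torch
import bitsandbytes  # noqa: F401  (importing registers the torch.ops.bitsandbytes operators)

A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device="cuda")

def int8_mm(a, b):
    # int8 x int8 -> int32 matmul provided by the custom op added in this PR
    return torch.ops.bitsandbytes.int8_linear_matmul(a, b)

compiled = torch.compile(int8_mm, fullgraph=True)
assert torch.equal(int8_mm(A, B), compiled(A, B))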