Unverified Commit e82f72b3 authored by Matthew Douglas, committed by GitHub

PyTorch Custom Operator Integration (#1544)



* Sketch out first custom op registration (an illustrative registration sketch follows the commit metadata below)

* Add note

* Initial int8 op registration

* Cleanup some deprecated functions.

* Int8 ops updates; tests

* Implement 4bit quant/dequant ops

* Fix nested quant

* cleanup

* Test improvements

* Clean up and improve tests

* Add higher level custom op for int8 matmul + dequant + bias

* Add gemv 4bit custom op

* Cleanup

* Implement out kwarg overloads for custom ops

* Update PyTorch minimum to 2.1

* Deprecation updates

* Deprecation updates

* Cleanup; rename int8_linear_dequant -> int8_scaled_mm

* Bump min pytorch to 2.2

* cleanup

* Test reorganization

* Remove deprecated supports_igemmlt

* More cleanup

* Cleanup obsolete C++/CUDA code

* Cleanup

* Create 'default' backend for fallback op implementations; initial CPU nf4 work

* Stub out for multi-platform

* Fix serialization tests for torch>=2.6.0

* Add example for torch.compile e2e inference

* Test update

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent f0735f95
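The commit registers bitsandbytes kernels as PyTorch custom operators so that torch.compile can trace through them and dispatch per backend. As a rough, hedged illustration of the general torch.library registration pattern only (not the code added in this commit): the namespace `mylib`, the op name `int8_scaled_mm_example`, and both implementations below are hypothetical.

import torch

# Hypothetical namespace and op: a sketch of the torch.library registration
# pattern the commit message refers to, not code from this commit.
lib = torch.library.Library("mylib", "DEF")
lib.define("int8_scaled_mm_example(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats) -> Tensor")


def _int8_scaled_mm_example(A, B, row_stats, col_stats):
    # Reference implementation: accumulate the int8 product in float32, then
    # rescale by the per-row/per-column absmax statistics.
    acc = torch.matmul(A.float(), B.float().t())
    scale = row_stats.float().unsqueeze(-1) * col_stats.float().unsqueeze(0) / (127.0 * 127.0)
    return (acc * scale).to(torch.float16)


def _int8_scaled_mm_example_meta(A, B, row_stats, col_stats):
    # Shape/dtype-only "fake" kernel so torch.compile can trace the op
    # without running it.
    return A.new_empty((A.shape[0], B.shape[0]), dtype=torch.float16)


lib.impl("int8_scaled_mm_example", _int8_scaled_mm_example, "CPU")
lib.impl("int8_scaled_mm_example", _int8_scaled_mm_example, "CUDA")
lib.impl("int8_scaled_mm_example", _int8_scaled_mm_example_meta, "Meta")

if __name__ == "__main__":
    A = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
    B = torch.randint(-128, 127, (6, 8), dtype=torch.int8)
    row_stats, col_stats = torch.rand(4) * 3, torch.rand(6) * 3
    out = torch.ops.mylib.int8_scaled_mm_example(A, B, row_stats, col_stats)
    print(out.shape, out.dtype)  # torch.Size([4, 6]) torch.float16

Registering a shape-only "Meta" kernel alongside the real backend kernels is generally what allows such ops to participate in graph capture without graph breaks.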
import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")

# Quantize the weights to int8 (LLM.int8()) on load.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# torch._dynamo.config.capture_dynamic_output_shape_ops = True

model_id = "google/gemma-2-2b-it"
# model_id = "Qwen/Qwen2.5-7B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

# Compile the full model; the registered custom ops keep the graph traceable.
# model.forward = torch.compile(model.forward, fullgraph=True)
model = torch.compile(model)

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
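The example above uses 8-bit weights; the commit also covers 4-bit quant/dequant and gemv ops. Assuming the same transformers API as above, a 4-bit NF4 variant would only change the quantization config, with the tokenizer, torch.compile, and generate calls left as-is. The values below are a plausible sketch, not configuration taken from this commit.

from transformers import BitsAndBytesConfig
import torch

# Hypothetical 4-bit NF4 setup with nested (double) quantization.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)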
@@ -42,7 +42,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Artificial Intelligence"
 ]
 dependencies = [
-    "torch>=2.0,<3",
+    "torch>=2.2,<3",
     "numpy>=1.17"
 ]
...
@@ -22,7 +22,7 @@ def torch_save_to_buffer(obj):
 def torch_load_from_buffer(buffer):
     buffer.seek(0)
-    obj = torch.load(buffer)
+    obj = torch.load(buffer, weights_only=False)
     buffer.seek(0)
     return obj
@@ -36,6 +36,8 @@ def format_with_label(label: str, value: Any) -> str:
         formatted = "T" if value else "F"
     elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
         formatted = "".join("T" if b else "F" for b in value)
+    elif isinstance(value, torch.dtype):
+        formatted = describe_dtype(value)
     else:
         formatted = str(value)
     return f"{label}={formatted}"
...
-from typing import Tuple
 import pytest
 import torch
 import bitsandbytes as bnb
 from tests.helpers import (
     BOOLEAN_TRIPLES,
-    BOOLEAN_TUPLES,
     TRUE_FALSE,
     describe_dtype,
     get_test_dims,
@@ -16,189 +13,6 @@ from tests.helpers import (
 TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize(
"funcs",
[(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(25):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
elif not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
elif transpose[0] and not transpose[1]:
out_torch = funcs[0](A.t(), B)
out_bnb = funcs[1](A.t(), B)
elif transpose[0] and transpose[1]:
out_torch = funcs[0](A.t(), B.t())
out_bnb = funcs[1](A.t(), B.t())
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
B = torch.randn(
size=(dim1, dim3, dim4),
device="cuda",
requires_grad=req_grad[1],
)
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
else:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
......
import numpy as np
import pytest
from scipy.stats import norm
import torch

from bitsandbytes import functional as F


@pytest.mark.deprecated
def test_kbit_quantile_estimation():
    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")

        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")

        for bits in range(2, 4):
            total_values = 2**bits - 1
            p = np.linspace(0, 1, 2 * total_values + 1)
            idx = np.arange(1, 2 * total_values + 1, 2)
            p = p[idx]
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


@pytest.mark.deprecated
def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


@pytest.mark.deprecated
def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs) / len(diffs))
    print(sum(reldiffs) / len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(20):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)
-from itertools import product
 import math
 import random
 import time
@@ -6,7 +5,6 @@ import time
 import einops
 import numpy as np
 import pytest
-from scipy.stats import norm
 import torch
 import bitsandbytes as bnb
@@ -88,77 +86,194 @@ class Timer:
         print("Resetting benchmark data")
def setup(): class Test8BitBlockwiseQuantizeFunctional:
pass @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def teardown(): @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
pass def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed):
diffs = []
reldiffs = []
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) for i in range(100):
def test_estimate_quantiles(dtype): A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
A = torch.rand(1024, 1024, device="cuda") C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
A = A.to(dtype) A2 = F.dequantize_blockwise(C, S)
code = F.estimate_quantiles(A) diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) diffs.append(diff.mean().item())
torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) reldiffs.append(reldiff.mean().item())
abserr = sum(diffs) / len(diffs)
A = torch.randn(1024, 1024, device="cuda") relerr = sum(reldiffs) / len(reldiffs)
A = A.to(dtype) # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
code = F.estimate_quantiles(A) # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert abserr < 0.011
quantiles = torch.quantile(A.float(), percs) assert relerr < 0.018
diff = torch.abs(code - quantiles) assert A2.dtype == dtype
assert (diff > 5e-02).sum().item() == 0
diffs = []
code = F.create_dynamic_map(signed=signed)
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
# torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
if signed:
assert abserr < 0.0035
assert relerr < 0.015
else:
assert abserr < 0.00175
assert relerr < 0.012
assert A2.dtype == dtype
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_blockwise_cpu_large(self):
diffs = []
reldiffs = []
batch = 128
seq = 128
for hidden in [128]: # , 14336]:
for blocksize in [4096, 16384]:
for i in range(2):
A1 = torch.randn(batch, seq, hidden, device="cpu")
t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
print(time.time() - t0)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
def test_few_bit_quant(self, bits, method):
abserrs = []
relerrs = []
code = None
if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == "fp8":
ebits = math.ceil(bits / 2)
pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).cuda()
elif method == "quantile":
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
# print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device="cuda")
values /= values.abs().max()
# values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v - code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2 - values)
abserrs.append(err2.mean().item())
relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1 - values).mean()
# assert err2.mean() <= err1
else:
torch.testing.assert_close(q1, q2)
def test_fp8_quant(self):
for e_bits in range(1, 7):
p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(3, sum(abserr)/len(abserr))
# print(3, sum(relerr)/len(relerr))
@pytest.mark.benchmark
def test_bench_dequantization(self):
a = torch.rand(1024, 1024, device="cuda").half()
code = F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max())
max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
# print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize()
# print((time.time()-t0)/1e6)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) def test_stable_embedding():
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) layer = bnb.nn.StableEmbedding(1024, 1024)
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) layer.reset_parameters()
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
# print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
# print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert abserr < 0.011
assert relerr < 0.018
assert A2.dtype == dtype
diffs = []
code = F.create_dynamic_map(signed=signed)
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
# torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
if signed:
assert abserr < 0.0035
assert relerr < 0.015
else:
assert abserr < 0.00175
assert relerr < 0.012
assert A2.dtype == dtype
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
 def quant(x):
@@ -198,11 +313,6 @@ def quant_multi_chunk(x, dim, chunk_size=32):
     return max1, x.to(torch.int8)
-def quant_minmax(A):
-    minA = A.min()
-    maxA = A.max()
 def mean(xx):
     return sum(xx) / float(len(xx))
@@ -219,531 +329,618 @@ methods = {
 }
@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) class TestIGEMMFunctional:
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) @pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) @pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
def test_approx_igemm(dim1, dim2, quant_methods, batched): @pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
dim1 = dim1 - (dim1 % 32) def test_approx_igemm(self, dim1, dim2, quant_methods, batched):
dim2 = dim2 - (dim2 % 32) dim1 = dim1 - (dim1 % 32)
errors = [] dim2 = dim2 - (dim2 % 32)
relerrors = [] errors = []
# print("") relerrors = []
for i in range(5): # print("")
if batched: for i in range(5):
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") if batched:
B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
maxA, Ac = quant_methods[0](A, 2) B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxB, Bc = quant_methods[1](B, 1) maxA, Ac = quant_methods[0](A, 2)
else: maxB, Bc = quant_methods[1](B, 1)
A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") else:
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
maxA, Ac = quant_methods[0](A, 1) B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxB, Bc = quant_methods[1](B, 0) maxA, Ac = quant_methods[0](A, 1)
torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) maxB, Bc = quant_methods[1](B, 0)
if batched: torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
out2 = torch.bmm(A, B) if batched:
C = torch.bmm(Ac.float(), Bc.float()) out2 = torch.bmm(A, B)
else: C = torch.bmm(Ac.float(), Bc.float())
out2 = torch.mm(A, B) else:
C = F.igemm(Ac, Bc) out2 = torch.mm(A, B)
out = quant_methods[4](maxA, maxB, C) C = F.igemm(Ac, Bc)
std = out2.std() out = quant_methods[4](maxA, maxB, C)
out /= std std = out2.std()
out2 /= std out /= std
err = torch.abs(out - out2) out2 /= std
relerr = err / torch.abs(out2) err = torch.abs(out - out2)
errors.append(err.mean().item()) relerr = err / torch.abs(out2)
relerrors.append(relerr.mean().item()) errors.append(err.mean().item())
# print(mean(errors)) relerrors.append(relerr.mean().item())
# print(mean(relerrors)) # print(mean(errors))
# print(mean(relerrors))
def test_stable_embedding(): @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
layer = bnb.nn.StableEmbedding(1024, 1024) @pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
layer.reset_parameters() @pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) hidden_dim = hidden_dim - (hidden_dim % 32)
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) batch_dim = batch_dim - (batch_dim % 16)
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) seq_dim = seq_dim - (seq_dim % 16)
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) for i in range(k):
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
hidden_dim = hidden_dim - (hidden_dim % 32) shapeB = (
batch_dim = batch_dim - (batch_dim % 16) (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
seq_dim = seq_dim - (seq_dim % 16) )
for i in range(k): A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) if not transpose[0] and not transpose[1]:
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) out2 = torch.matmul(A.float(), B.float())
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) out = F.igemm(A, B)
if not transpose[0] and not transpose[1]: elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.float()) out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B) out = F.igemm(A, B.t())
elif not transpose[0] and transpose[1]: elif transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.t().float()) out2 = torch.matmul(A.t().float(), B.float())
out = F.igemm(A, B.t()) out = F.igemm(A.t(), B)
elif transpose[0] and not transpose[1]: elif transpose[0] and transpose[1]:
out2 = torch.matmul(A.t().float(), B.float()) out2 = torch.matmul(A.t().float(), B.t().float())
out = F.igemm(A.t(), B) out = F.igemm(A.t(), B.t())
elif transpose[0] and transpose[1]:
out2 = torch.matmul(A.t().float(), B.t().float()) torch.testing.assert_close(out.float(), out2)
out = F.igemm(A.t(), B.t())
for i in range(k):
torch.testing.assert_close(out.float(), out2) shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = (
for i in range(k): (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
shapeA = (batch_dim, seq_dim, hidden_dim) )
shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]:
if not transpose[0] and not transpose[1]: out2 = torch.matmul(A.float(), B.float())
out2 = torch.matmul(A.float(), B.float()) out = F.igemm(A, B)
out = F.igemm(A, B) elif not transpose[0] and transpose[1]:
elif not transpose[0] and transpose[1]: out2 = torch.matmul(A.float(), B.t().float())
out2 = torch.matmul(A.float(), B.t().float()) out = F.igemm(A, B.t())
out = F.igemm(A, B.t())
torch.testing.assert_close(out.float(), out2)
torch.testing.assert_close(out.float(), out2)
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) seq_dim = seq_dim - (seq_dim % 32)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): hidden_dim = hidden_dim - (hidden_dim % 32)
seq_dim = seq_dim - (seq_dim % 32) batch_dim = batch_dim - (batch_dim % 2)
hidden_dim = hidden_dim - (hidden_dim % 32) for i in range(25):
batch_dim = batch_dim - (batch_dim % 2) A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
for i in range(25): B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) out = F.igemm(A, B, out=iout)
iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
out = F.igemm(A, B, out=iout) torch.testing.assert_close(out.float(), out2)
torch.testing.assert_close(out.float(), out2) @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) def min_max(x):
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) maxA = torch.amax(x, dim=2, keepdim=True)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): minA = torch.amin(x, dim=2, keepdim=True)
def min_max(x): scale = (maxA - minA) / 2.0
maxA = torch.amax(x, dim=2, keepdim=True) return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
minA = torch.amin(x, dim=2, keepdim=True)
scale = (maxA - minA) / 2.0 seq_dim = seq_dim - (seq_dim % 16)
return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale hidden_dim = hidden_dim - (hidden_dim % 16)
batch_dim = batch_dim - (batch_dim % 2)
seq_dim = seq_dim - (seq_dim % 16) errs = []
hidden_dim = hidden_dim - (hidden_dim % 16) relerrs = []
batch_dim = batch_dim - (batch_dim % 2) errs2 = []
errs = [] relerrs2 = []
relerrs = [] for i in range(k):
errs2 = [] A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
relerrs2 = [] if transpose:
for i in range(k): B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") else:
if transpose: B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") Ac, minA, scale = min_max(A)
else: if transpose:
B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
Ac, minA, scale = min_max(A) out = F.igemm(Ac, Bc.t())
if transpose: out2 = torch.matmul(A, B.t())
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) offset = B.t().sum(0) * (minA + scale)
out = F.igemm(Ac, Bc.t()) out = out.float()
out2 = torch.matmul(A, B.t()) out = (out * maxB.t() * scale / (127 * 127)) + offset
offset = B.t().sum(0) * (minA + scale)
out = out.float() maxA, Ac = quant_multi(A, dim=2)
out = (out * maxB.t() * scale / (127 * 127)) + offset out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
maxA, Ac = quant_multi(A, dim=2) else:
out3 = F.igemm(Ac, Bc.t()) maxB, Bc = quant_multi(B, dim=0)
out3 = mm_dequant(maxA, maxB.t(), out3) offset = B.sum(0) * (minA + scale)
else: out = F.igemm(Ac, Bc)
maxB, Bc = quant_multi(B, dim=0) out2 = torch.matmul(A, B)
offset = B.sum(0) * (minA + scale) out = out.float()
out = F.igemm(Ac, Bc) out = (out * maxB * scale / (127 * 127)) + offset
out2 = torch.matmul(A, B)
out = out.float() maxA, Ac = quant_multi(A, dim=2)
out = (out * maxB * scale / (127 * 127)) + offset out3 = F.igemm(Ac, Bc)
out3 = mm_dequant(maxA, maxB, out3)
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc) std = out2.std()
out3 = mm_dequant(maxA, maxB, out3) out2 /= std
out /= std
std = out2.std() out3 /= std
out2 /= std
out /= std err = torch.abs(out - out2)
out3 /= std relerr = err / (torch.abs(out2) + 1e-7)
err = torch.abs(out - out2) err2 = torch.abs(out3 - out2)
relerr = err / (torch.abs(out2) + 1e-7) relerr2 = err2 / (torch.abs(out2) + 1e-7)
err2 = torch.abs(out3 - out2) errs.append(err.mean().item())
relerr2 = err2 / (torch.abs(out2) + 1e-7) relerrs.append(relerr.mean().item())
errs2.append(err2.mean().item())
errs.append(err.mean().item()) relerrs2.append(relerr2.mean().item())
relerrs.append(relerr.mean().item()) # print(mean(errs))
errs2.append(err2.mean().item()) # print(mean(relerrs))
relerrs2.append(relerr2.mean().item()) # print(mean(errs2))
# print(mean(errs)) # print(mean(relerrs2))
# print(mean(relerrs)) assert mean(errs) < 0.015
# print(mean(errs2)) assert mean(relerrs) < 0.3
# print(mean(relerrs2))
assert mean(errs) < 0.015 @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
assert mean(relerrs) < 0.3 @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) def test_ibmm(self, dim1, dim2, dim3, dim4, transpose):
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) dim2 = dim2 - (dim2 % 16)
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) dim3 = dim3 - (dim3 % 16)
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) dim4 = dim4 - (dim4 % 16)
def test_ibmm(dim1, dim2, dim3, dim4, transpose): for i in range(k):
dim2 = dim2 - (dim2 % 16) shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
dim3 = dim3 - (dim3 % 16) shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
dim4 = dim4 - (dim4 % 16) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
for i in range(k): B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) if not transpose[0] and not transpose[1]:
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) out2 = torch.bmm(A.float(), B.float())
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
if not transpose[0] and not transpose[1]: out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
out2 = torch.bmm(A.float(), B.float()) out = F.igemm(A, B.permute([0, 2, 1]))
out = F.igemm(A, B) elif transpose[0] and not transpose[1]:
elif not transpose[0] and transpose[1]: out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) out = F.igemm(A.permute([0, 2, 1]), B)
out = F.igemm(A, B.permute([0, 2, 1])) elif transpose[0] and transpose[1]:
elif transpose[0] and not transpose[1]: out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
out = F.igemm(A.permute([0, 2, 1]), B) torch.testing.assert_close(out.float(), out2.float())
elif transpose[0] and transpose[1]:
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) class TestLLMInt8Functional:
torch.testing.assert_close(out.float(), out2.float()) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb):
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) for i in range(k):
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) if dims == 2:
def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb): A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
for i in range(k): elif dims == 3:
if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
elif dims == 3: C1 = torch.matmul(A.float(), B.t().float())
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C2 = F.int8_linear_matmul(A, B)
C1 = torch.matmul(A.float(), B.t().float()) torch.testing.assert_close(C1, C2.float())
C2 = F.int8_linear_matmul(A, B) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
torch.testing.assert_close(C1, C2.float()) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims):
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) for i in range(k):
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) if dims == 2:
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims): elif dims == 3:
for i in range(k): A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
if dims == 2: B = torch.randn((dim4, dim3), device="cuda").half()
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B)
elif dims == 3: C1 = torch.matmul(A, B.t())
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
B = torch.randn((dim4, dim3), device="cuda").half() A = A.view(-1, A.shape[-1])
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t()) CA, _, statsA, _, _ = F.int8_double_quant(A)
CB, statsB, _ = F.int8_vectorwise_quant(B)
A = A.view(-1, A.shape[-1]) output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
CA, _, statsA, _, _ = F.int8_double_quant(A) torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
CB, statsB, _ = F.int8_vectorwise_quant(B)
output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(self, dim1, dim4, dims, has_bias):
@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) inner = 128
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) bias = None
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(dim1, dim4, dims, has_bias):
inner = 128
bias = None
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
for i in range(1):
A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
if has_bias:
C1 += bias
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
C2 = F.int8_linear_matmul(A1, B1)
C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
if has_bias: if has_bias:
C4 += bias bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
# TODO: is something wrong here? If so, the problem goes deeper for i in range(1):
# n = C1.numel() A = torch.randn(dim1, inner, device="cuda")
# p = 0.06 B = torch.randn(dim4, inner, device="cuda")
std = C1.std(0).view(1, -1) C1 = torch.matmul(A.half(), B.t().half())
C1 /= std if has_bias:
C4 /= std C1 += bias
# assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
# assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias)
C5 /= std C2 = F.int8_linear_matmul(A1, B1)
torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
n = C5.numel() C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) if has_bias:
C4 += bias
@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) # TODO: is something wrong here? If so, the problem goes deeper
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) # n = C1.numel()
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) # p = 0.06
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) std = C1.std(0).view(1, -1)
def test_colrow_absmax(dim1, dim2, dims, threshold): C1 /= std
for i in range(k): C4 /= std
A = torch.randn(dim1, dim2, device="cuda").half() # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
# assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
assert dims == 2
C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias)
row_stats1, _ = torch.abs(A.float()).max(1) C5 /= std
col_stats1, _ = torch.abs(A.float()).max(0) torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
n = C5.numel()
if threshold > 0.0: assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= threshold] = 0.0 @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) def test_colrow_absmax(self, dim1, dim2, dims, threshold):
for i in range(k):
nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten() A = torch.randn(dim1, dim2, device="cuda").half()
nnz_block_ptr1 = torch.zeros(
nnz_rows1_counts.shape[0] + 1, assert dims == 2
dtype=nnz_rows1_counts.dtype,
device=nnz_rows1_counts.device, row_stats1, _ = torch.abs(A.float()).max(1)
col_stats1, _ = torch.abs(A.float()).max(0)
if threshold > 0.0:
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
nnz_block_ptr1 = torch.zeros(
nnz_rows1_counts.shape[0] + 1,
dtype=nnz_rows1_counts.dtype,
device=nnz_rows1_counts.device,
)
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_close(col_stats1_trunc, col_stats2)
torch.testing.assert_close(row_stats1_trunc, row_stats2)
# torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
else:
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
assert nnz_block_ptr2 is None
torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_close(row_stats1, row_stats2)
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(self, dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)
# max difference is 1 due to rounding differences
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
print(
f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
)
assert False
if num_not_close_rows > (min_error * n):
print(
f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
)
assert False
torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_close(Scol.flatten().float(), statsAt)
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner) in zip(
(1, 8, 2048, 4096),
(2, 128, 2048, 4096),
(4, 256, 512, 4096),
) )
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) ),
)
torch.testing.assert_close(col_stats1_trunc, col_stats2) def test_integrated_int8_linear_matmul(self, dim1, dim4, inner):
torch.testing.assert_close(row_stats1_trunc, row_stats2) for i in range(k):
# torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2) A = torch.randn(dim1, inner, device="cuda").half()
else: B = torch.randn(dim4, inner, device="cuda").half()
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
assert nnz_block_ptr2 is None out1 = torch.matmul(A.half(), B.t().half())
torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_close(row_stats1, row_stats2) C1a, stats1a, _ = F.int8_vectorwise_quant(A)
C2a, stats2a, _ = F.int8_vectorwise_quant(B)
A1, maxA = F.vectorwise_quant(A, dim=1)
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) B1, maxB = F.vectorwise_quant(B, dim=1)
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(dim1, dim2): torch.testing.assert_close(maxA.flatten().float(), stats1a)
for i in range(k): torch.testing.assert_close(maxB.flatten().float(), stats2a)
torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
out2 = F.int8_linear_matmul(A1, B1)
C2 = F.int8_linear_matmul(A1, B1)
out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.025
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(self, dim1, dim2):
threshold = 2.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
idx = torch.abs(A) >= threshold
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
if outlier_cols is not None:
A1 = A * idx
A2 = torch.zeros_like(A) + A1
torch.testing.assert_close(A1, A2)
A[:, outlier_cols] = 0
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(self, dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
idx = torch.abs(A) >= threshold
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
if outlier_cols is not None:
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
A[:, outlier_cols] = 0
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
class TestSpMMFunctional:
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(self, dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
# dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
B = torch.randn(dim3, dim2).cuda().half()
else:
B = torch.randn(dim2, dim3).cuda().half()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
out1 = torch.matmul(A2, B.t())
else:
out2 = F.spmm_coo(cooA, B)
out1 = torch.matmul(A2, B)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
@pytest.mark.benchmark
def test_spmm_bench(self):
batch = 2
model = 1024 * 1
hidden = model * 4
seq = 1024
dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0) B = torch.randn(dim2, dim3, device="cuda").half()
out_row1, Srow = F.vectorwise_quant(A, dim=1) for i in range(10):
C1 = bnb.matmul(A, B.t())
CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)
# max difference is 1 due to rounding differences
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
assert False
if num_not_close_rows > (min_error * n):
print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
assert False
torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_close(Scol.flatten().float(), statsAt)
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner) in zip(
(1, 8, 2048, 4096),
(2, 128, 2048, 4096),
(4, 256, 512, 4096),
)
),
)
def test_integrated_int8_linear_matmul(dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
C1a, stats1a, _ = F.int8_vectorwise_quant(A)
C2a, stats2a, _ = F.int8_vectorwise_quant(B)
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
torch.testing.assert_close(maxA.flatten().float(), stats1a)
torch.testing.assert_close(maxB.flatten().float(), stats2a)
torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
out2 = F.int8_linear_matmul(A1, B1)
C2 = F.int8_linear_matmul(A1, B1)
out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = bnb.matmul(A, B.t())
torch.cuda.synchronize()
t8 = time.time() - t0
err1 = torch.abs(out1 - out2).mean().item() idx = torch.abs(A) >= threshold
err2 = torch.abs(out1 - out3).mean().item() nnz = (idx == 1).sum().item()
assert err2 <= err1 * 1.025 print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
for i in range(10):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
tsp = time.time() - t0
print(tsp, t8)
print(tsp / t8)
@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner) in zip(
get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6),
)
),
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
# C3, S = F.nvidia_transform(outC32, "row", state=SC)
C3 = outC32
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
scale = maxval / 120
out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32 = F.int8_linear_matmul(A2, B2)
out2 = F.int8_mm_dequant(outC32, stats1a, stats2a)
CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
out4 = C * SA * SB / (127 * 127)
# out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
# print('='*80)
# print(out1)
# print(out2)
# print(out3)
else:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
B, SB = F.vectorwise_quant(B, quant_type="linear")
# B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
print("")
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
# print(B)
# print(out1)
# print(out2)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
# torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
# print(out1)
# print(out2)
# print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
# assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print("")
print(sum(err1) / len(err1))
print(sum(err2) / len(err2))
print(sum(err3) / len(err3))
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
def test_coo_double_quant(dim1, dim2): # torch.cuda.synchronize()
threshold = 2.00 # t0 = time.time()
for i in range(k): # print(A2.shape, B.shape)
A = torch.randn(dim1, dim2, device="cuda").half() # for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
idx = torch.abs(A) >= threshold # #out2 = F.spmm_coo(cooA, B)
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) # #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
if outlier_cols is not None:
A1 = A * idx # torch.cuda.synchronize()
A2 = torch.zeros_like(A) + A1 # print(time.time() - t0)
torch.testing.assert_close(A1, A2)
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
A[:, outlier_cols] = 0 @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() @pytest.skip("No longer supported")
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) def test_integrated_sparse_decomp(self, dim1, dim2):
threshold = 3.0
for _ in range(k):
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) A = torch.randn(dim1, dim2).cuda().half()
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) w1 = torch.randn(dim1, dim2).cuda().half()
def test_coo_int8_vectorwise_quant(dim1, dim2): out1 = torch.matmul(A, w1.t())
threshold = 3.00
for i in range(k): Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
CA, statsA, _ = F.int8_vectorwise_quant(A)
out1_32 = F.int8_linear_matmul(CA, Cw1)
out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)
# CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
out1_32 = F.int8_linear_matmul(CA, Cw1)
out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
assert coo_tensor is not None
out4 = F.spmm_coo(coo_tensor, w1.t())
# idx = torch.unique(coo_tensor._indices()[1]).long()
# out4 = torch.matmul(A, w1.t())
out5 = out3 + out4
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(self, dim1, dim2, dtype):
threshold = 6.0
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
idx = torch.abs(A) >= threshold CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
if outlier_cols is not None:
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
A[:, outlier_cols] = 0
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) A[:, rowidx] = 8.0
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
# dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
B = torch.randn(dim3, dim2).cuda().half()
else:
B = torch.randn(dim2, dim3).cuda().half()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
...@@ -751,712 +948,381 @@ def test_spmm_coo(dim1, dim2, transposed_B):
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
out1 = torch.matmul(A2, B.t())
else:
out2 = F.spmm_coo(cooA, B)
out1 = torch.matmul(A2, B)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
@pytest.mark.benchmark
def test_spmm_bench():
batch = 2
model = 1024 * 1
hidden = model * 4
seq = 1024
dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B.t())
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = bnb.matmul(A, B.t())
torch.cuda.synchronize()
t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
for i in range(10):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
tsp = time.time() - t0
print(tsp, t8)
print(tsp / t8)
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
for _ in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
out1 = torch.matmul(A, w1.t())
Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
CA, statsA, _ = F.int8_vectorwise_quant(A)
out1_32 = F.int8_linear_matmul(CA, Cw1)
out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)
# CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
out1_32 = F.int8_linear_matmul(CA, Cw1)
out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
assert coo_tensor is not None
out4 = F.spmm_coo(coo_tensor, w1.t())
# idx = torch.unique(coo_tensor._indices()[1]).long()
# out4 = torch.matmul(A, w1.t())
out5 = out3 + out4
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
def test_matmuls():
a = torch.randn(256, 512).half().cuda()
b = torch.randn(256, 512).half().cuda()
c1 = torch.matmul(a, b.t())
c2 = bnb.matmul(a, b)
c3 = bnb.matmul_cublas(a, b.t())
err1 = torch.abs(c1 - c2).mean().item()
err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
print(err1, err2)
@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
else:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
B, SB = F.vectorwise_quant(B, quant_type="linear")
# B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
print("")
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
# print(B)
# print(out1)
# print(out2)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
# torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
# torch.cuda.synchronize()
# t0 = time.time()
# print(A2.shape, B.shape)
# for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
# torch.cuda.synchronize()
# print(time.time() - t0)
def test_coo2csr():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_close(A2[idx], csrA.values)
def test_coo2csc():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0
torch.testing.assert_close(A2.t()[idx], cscA.values)
@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
A[:, rowidx] = 8.0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
max_count, max_idx = torch.sort(counts, descending=True)
print(torch.median(max_count.float()))
torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
# torch.cuda.synchronize()
# print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = torch.matmul(A, B)
torch.cuda.synchronize()
print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1 + out2
torch.cuda.synchronize()
print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
print("partial matmul", time.time() - t0)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
max_count, max_idx = torch.sort(counts, descending=True)
print(torch.median(max_count.float()))
torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
# torch.cuda.synchronize()
# print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
print("cusparse fp16", time.time() - t0)
def test_zeropoint():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0:
dyna = 1
qx = 254.0 / dyna
minx = x.min()
# zpx = torch.round(minx* qx)
# zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min() * qx) - 127
x = (qx * x) + zpx
return x, qx, zpx
batch = 2
seq = 512
model = 1024
hidden = 4 * model
A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
# A, SA = F.vectorwise_quant(A, quant_type='linear')
# B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
C1 = torch.matmul(A, B)
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
# C2 = torch.matmul(A-zp, B)
# C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
# print(ca.min(), ca.max())
# print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
C5 = torch.matmul((A * scale) - zp, B)
C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C6 -= zpa * zpb * A.shape[1]
C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb
# print("")
# print(C0.flatten()[:10])
# print(C1.flatten()[:10])
# print(C2.flatten()[:10])
# print(C3.flatten()[:10])
# print(C5.flatten()[:10])
# print(C6.flatten()[:10])
# print(C7.flatten()[:10])
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
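# Note on the algebra exercised above (a sketch, not part of the original test): with the
# per-tensor quantization CA = qa * A + zpa and CB = qb * B + zpb, expanding the product gives
#   CA @ CB = qa * qb * (A @ B) + qb * zpa * B.sum(0) + qa * zpb * A.sum(1, keepdim=True) + zpa * zpb * A.shape[1]
# so subtracting the three correction terms and dividing by qa * qb recovers A @ B, which is
# exactly what the C6 and C7 computations undo step by step.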
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = torch.matmul(A, B)
torch.cuda.synchronize()
print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1 + out2
torch.cuda.synchronize()
print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
print("partial matmul", time.time() - t0)
class TestSparseTensorFunctional:
def test_coo2csr(self):
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_close(A2[idx], csrA.values)
@pytest.mark.deprecated
def test_extract_outliers():
for i in range(k):
shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
# idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()]
CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_close(outliers1, outliers2)
CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_close(outliers1, outliers2)
def test_blockwise_cpu_large():
diffs = []
reldiffs = []
batch = 128
seq = 128
for hidden in [128]: # , 14336]:
for blocksize in [4096, 16384]:
for i in range(2):
A1 = torch.randn(batch, seq, hidden, device="cpu")
t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
print(time.time() - t0)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
def test_coo2csc(self):
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0
torch.testing.assert_close(A2.t()[idx], cscA.values)
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(3, sum(abserr)/len(abserr))
# print(3, sum(relerr)/len(relerr))
def test_few_bit_quant():
# print('')
for bits in range(2, 9):
# print('='*30, bits, '='*30)
for method in ["linear", "fp8", "dynamic", "quantile"]:
abserrs = []
relerrs = []
code = None
if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == "fp8":
ebits = math.ceil(bits / 2)
pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).cuda()
elif method == "quantile":
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
# print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device="cuda")
values /= values.abs().max()
# values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v - code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2 - values)
abserrs.append(err2.mean().item())
relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1 - values).mean()
# assert err2.mean() <= err1
else:
torch.testing.assert_close(q1, q2)
# print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
# assert False
class TestQuantize4BitFunctional:
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(self, dtype, quant_type, blocksize):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1 - val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 4):
total_values = 2**bits - 1
p = np.linspace(0, 1, 2 * total_values + 1)
idx = np.arange(1, 2 * total_values + 1, 2)
p = p[idx]
offset = 1 / (2 * total_values)
p = np.linspace(offset, 1 - offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
err = torch.abs(val1 - val2).mean()
assert err < 0.035
@pytest.mark.benchmark
def test_bench_dequantization():
a = torch.rand(1024, 1024, device="cuda").half()
code = F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max())
max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
# print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize()
# print((time.time()-t0)/1e6)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(dtype, quant_type, blocksize):
vals = list(product([0, 1], repeat=4))
code = {}
for bits in vals:
result = 0
bias = 3
sign, e1, e2, p1 = bits
idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
sign = -1.0 if sign else 1.0
exp = e1 * 2 + e2 * 1
if exp == 0:
# sub-normal
if p1 == 0:
result = 0
else:
result = sign * 0.0625
else:
# normal
exp = 2 ** (-exp + bias + 1)
frac = 1.5 if p1 else 1.0
result = sign * exp * frac
code[idx] = result
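# The loop above enumerates all 16 bit patterns of an FP4-style format (1 sign bit, 2 exponent
# bits, 1 mantissa bit, bias 3): exp == 0 yields the subnormals 0 and +/-0.0625, otherwise the
# value is sign * 2 ** (bias + 1 - exp) * (1.5 if the mantissa bit is set else 1.0).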
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err / (A1.abs().float() + 1e-8)).mean()
idx = err > 1.0
err = err.mean()
assert A2.dtype == dtype
# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
if blocksize <= 64:
assert err.item() < 0.1
assert relerr.item() < 0.28
elif blocksize <= 256:
assert err.item() < 0.11
assert relerr.item() < 0.30
elif blocksize <= 512:
assert err.item() < 0.12
assert relerr.item() < 0.31
elif quant_type == "fp4":
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
else:
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert err.item() < math.log2(blocksize) * 8e-2
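# The thresholds above loosen with the block size because each block shares a single absmax
# scale: larger blocks mean a coarser scale per value, so the tolerated absolute error grows
# roughly with log2(blocksize).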
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
for i in range(10):
A1 = torch.randn(1024, 1024, device="cuda").half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean()
errs1.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
err = (A1 - A3).abs().float()
relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean()
errs2.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
err = (A1 - A2).abs().float()
relerr = (err / (A1.abs().float() + 1e-8)).mean()
err = err.mean()
assert A2.dtype == dtype
# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
if blocksize <= 64:
assert err.item() < 0.1
assert relerr.item() < 0.28
elif blocksize <= 256:
assert err.item() < 0.11
assert relerr.item() < 0.30
elif blocksize <= 512:
assert err.item() < 0.12
assert relerr.item() < 0.31
elif quant_type == "fp4":
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
else:
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert err.item() < math.log2(blocksize) * 8e-2
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(self, quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
for i in range(10):
A1 = torch.randn(1024, 1024, device="cuda").half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean()
errs1.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
err = (A1 - A3).abs().float()
relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean()
errs2.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
# print(sum(errs1)/len(errs1), blocksize, quant_type)
# print(sum(errs2)/len(errs2), blocksize, quant_type)
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(self, quant_type):
blocksize = 256
a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel() / 2
output_size = a.numel() * 2
num_bytes = input_size + output_size
GB = num_bytes / 1e9
max_theoretical_s = GB / 768
# print(max_theoretical_s*1e6)
b = torch.randn(128, 1024 * 12, device="cuda").half()
iters = 100
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
# b.copy_(a)
torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# torch.matmul(b, a.t())
# torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
"quant_storage",
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
ids=describe_dtype,
)
def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]:
# for dim in [4*1024]:
# for dim in [1*16]:
errs1 = []
errs2 = []
errs3 = []
relerrs1 = []
relerrs2 = []
relerrs3 = []
max_errs1 = []
max_errs2 = []
max_errs3 = []
for i in range(100):
if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "fc2":
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
qB, state = F.quantize_4bit(
B,
quant_type=storage_type,
compress_statistics=double_quant,
quant_storage=quant_storage,
)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
err1 = (C1 - C2).abs().float()
err2 = (C3 - C2).abs().float()
err3 = (C3 - C1).abs().float()
mag1 = torch.abs(C1).float() + 1e-5
mag2 = torch.abs(C3).float() + 1e-5
mag3 = torch.abs(C3).float() + 1e-5
relerr1 = err1 / mag1
relerr2 = err2 / mag2
relerr3 = err3 / mag3
max_err1 = err1.max()
max_err2 = err2.max()
max_err3 = err3.max()
errs1.append(err1.mean().item())
errs2.append(err2.mean().item())
errs3.append(err3.mean().item())
relerrs1.append(relerr1.mean().item())
relerrs2.append(relerr2.mean().item())
relerrs3.append(relerr3.mean().item())
max_errs1.append(max_err1.item())
max_errs2.append(max_err2.item())
max_errs3.append(max_err3.item())
c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
absratio = err2 / err3
relratio = relerr2 / relerr3
maxratio = relerr2 / relerr3
# for debugging if the tests fails
#
# print('='*80)
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
# print(C1.flatten()[-20:])
# print(C2.flatten()[-20:])
# print(f'inference vs training abs: {err1}')
# print(f'inference vs training rel: {relerr1}')
# print(f'inference vs training max: {maxerr1}')
# print(f'inference vs training vs torch err ratio abs: {absratio}')
# print(f'inference vs training vs torch err ratio rel: {relratio}')
# print(f'inference vs training vs torch err ratio max: {maxratio}')
if dtype == torch.float16:
if dim <= 512:
assert err1 < 7e-5
assert relerr1 < 0.0008
else:
assert err1 < 6e-5
assert relerr1 < 2e-4
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.float32:
if dim <= 512:
assert err1 < 5e-8
assert relerr1 < 1e-6
assert maxerr1 < 1e-7
else:
assert err1 < 5e-8
assert relerr1 < 8e-6
assert maxerr1 < 1e-7
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.bfloat16:
if dim <= 512:
assert err1 < 6e-4
assert relerr1 < 0.007
assert maxerr1 < 0.015
else:
assert err1 < 2e-4
assert relerr1 < 0.002
assert maxerr1 < 0.0012
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(self, storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64 - (dim % 64)) for dim in dims]
# for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for dim in dims:
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
B = torch.eye(dim, dtype=dtype, device="cuda")
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
C2 = bnb.matmul_4bit(A, qB.t(), state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
torch.testing.assert_close(A, C3)
torch.testing.assert_close(A, C1)
torch.testing.assert_close(A, C2)
# torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
# torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
# print(sum(errs1)/len(errs1), blocksize, quant_type)
# print(sum(errs2)/len(errs2), blocksize, quant_type)
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
blocksize = 256
a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel() / 2
output_size = a.numel() * 2
num_bytes = input_size + output_size
GB = num_bytes / 1e9
max_theoretical_s = GB / 768
# print(max_theoretical_s*1e6)
b = torch.randn(128, 1024 * 12, device="cuda").half()
iters = 100
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
# b.copy_(a)
torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# torch.matmul(b, a.t())
# torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
def test_normal_map_tree():
...@@ -1474,146 +1340,6 @@ def test_normal_map_tree():
# print(pivots)
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
"quant_storage",
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]:
# for dim in [4*1024]:
# for dim in [1*16]:
errs1 = []
errs2 = []
errs3 = []
relerrs1 = []
relerrs2 = []
relerrs3 = []
max_errs1 = []
max_errs2 = []
max_errs3 = []
for i in range(100):
if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "fc2":
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
qB, state = F.quantize_4bit(
B,
quant_type=storage_type,
compress_statistics=double_quant,
quant_storage=quant_storage,
)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
err1 = (C1 - C2).abs().float()
err2 = (C3 - C2).abs().float()
err3 = (C3 - C1).abs().float()
mag1 = torch.abs(C1).float() + 1e-5
mag2 = torch.abs(C3).float() + 1e-5
mag3 = torch.abs(C3).float() + 1e-5
relerr1 = err1 / mag1
relerr2 = err2 / mag2
relerr3 = err3 / mag3
max_err1 = err1.max()
max_err2 = err2.max()
max_err3 = err3.max()
errs1.append(err1.mean().item())
errs2.append(err2.mean().item())
errs3.append(err3.mean().item())
relerrs1.append(relerr1.mean().item())
relerrs2.append(relerr2.mean().item())
relerrs3.append(relerr3.mean().item())
max_errs1.append(max_err1.item())
max_errs2.append(max_err2.item())
max_errs3.append(max_err3.item())
c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
absratio = err2 / err3
relratio = relerr2 / relerr3
maxratio = relerr2 / relerr3
# for debugging if the tests fails
#
# print('='*80)
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
# print(C1.flatten()[-20:])
# print(C2.flatten()[-20:])
# print(f'inference vs training abs: {err1}')
# print(f'inference vs training rel: {relerr1}')
# print(f'inference vs training max: {maxerr1}')
# print(f'inference vs training vs torch err ratio abs: {absratio}')
# print(f'inference vs training vs torch err ratio rel: {relratio}')
# print(f'inference vs training vs torch err ratio max: {maxratio}')
if dtype == torch.float16:
if dim <= 512:
assert err1 < 7e-5
assert relerr1 < 0.0008
else:
assert err1 < 6e-5
assert relerr1 < 2e-4
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.float32:
if dim <= 512:
assert err1 < 5e-8
assert relerr1 < 1e-6
assert maxerr1 < 1e-7
else:
assert err1 < 5e-8
assert relerr1 < 8e-6
assert maxerr1 < 1e-7
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.bfloat16:
if dim <= 512:
assert err1 < 6e-4
assert relerr1 < 0.007
assert maxerr1 < 0.015
else:
assert err1 < 2e-4
assert relerr1 < 0.002
assert maxerr1 < 0.0012
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.skip("Row scale has some bugs for ampere") @pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed(): def test_managed():
n = 32 * 10 n = 32 * 10
...@@ -1637,32 +1363,6 @@ def test_managed(): ...@@ -1637,32 +1363,6 @@ def test_managed():
assert (A == 17 * (2**3)).sum().item() == n * n assert (A == 17 * (2**3)).sum().item() == n * n
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64 - (dim % 64)) for dim in dims]
# for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for dim in dims:
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
B = torch.eye(dim, dtype=dtype, device="cuda")
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
C2 = bnb.matmul_4bit(A, qB.t(), state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
torch.testing.assert_close(A, C3)
torch.testing.assert_close(A, C1)
torch.testing.assert_close(A, C2)
# torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
# torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
...@@ -1676,169 +1376,3 @@ def test_vector_quant(dim1, dim2, dim3): ...@@ -1676,169 +1376,3 @@ def test_vector_quant(dim1, dim2, dim3):
A1 = F.vectorwise_dequant(qA, SA) A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel() n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
@pytest.mark.deprecated
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
@pytest.mark.deprecated
def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
print(sum(diffs) / len(diffs))
print(sum(reldiffs) / len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
percentile = 5
for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
gnorm_vec1[:] = gnorm2
else:
gnorm_vec1[step % 100] = gnorm2
vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile]
torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
torch.testing.assert_close(clip1, clip2)
torch.testing.assert_close(gnorm1, gnorm2)
@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
A.view(-1)[-1] = -1
if transpose:
At = A.t().contiguous()
out1, S1 = F.nvidia_transform(At, to_order=orderOut)
else:
out1, S1 = F.nvidia_transform(A, to_order=orderOut)
out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
# print(out1)
# print(out2)
torch.testing.assert_close(out1, out2)
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
@pytest.mark.deprecated
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
if dims == 3 and orderOut != "col32":
return
if dtype == torch.int32 and orderOut != "col32":
return
try:
func = F.get_transform_func(dtype, orderA, orderOut, transpose)
except ValueError as ve:
pytest.skip(str(ve)) # skip if not supported
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
out, S = F.nvidia_transform(A, to_order=orderOut)
if orderOut == "row":
torch.testing.assert_close(A.flatten(), out.flatten())
elif orderOut == "col":
torch.testing.assert_close(A.t().flatten(), out.flatten())
elif orderOut == "col32":
if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
assert out.numel() == n
elif orderOut == "col_turing":
# 32 col 8 row tiles
n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]):
for col in range(A.shape[1]):
i = row * A.shape[1]
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
row2 = (row % 8) * 32
assert A.flatten()[i + j] == A[row, col]
# assert A.flatten()[i+j] == out.flatten()[row2+col2]
# torch.testing.assert_close(A.flatten()[i+j], A[row, col])
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32":
out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
torch.testing.assert_close(A, out2)
...@@ -8,8 +8,6 @@ import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
...@@ -18,28 +16,9 @@ from tests.helpers import (
torch_save_to_buffer,
)
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
tile_indices = get_inverse_transform_indices(transform, tile_size)
cxb = transform(x)
torch.cuda.synchronize()
restored_x = undo_layout(cxb, tile_indices)
torch.cuda.synchronize()
assert restored_x.is_contiguous()
assert torch.all(torch.eq(restored_x, x))
def test_linear_no_igemmlt():
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
...@@ -139,7 +118,7 @@ def test_linear_serialization(
if not has_fp16_weights:
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
new_state_dict = torch.load(state_path_8bit)
new_state_dict = torch.load(state_path_8bit, weights_only=False)
new_linear_custom = Linear8bitLt(
linear.in_features,
......
from math import prod
import pytest
import torch
import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter
class TestLLMInt8Ops:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_linear_matmul(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
assert out.shape == (10, 30)
assert out.dtype == torch.int32
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_linear_matmul_out(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
out = torch.empty((10, 30), dtype=torch.int32, device=device)
torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
assert out.shape == (10, 30)
assert out.dtype == torch.int32
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))
@pytest.mark.parametrize("threshold", [0.0, 6.0])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_vectorwise_quant(self, threshold, device):
if device == "cpu":
pytest.skip("CPU implementation is not available")
A = torch.randn(10, 20, dtype=torch.float16, device=device)
A[1][0] = 1000.0
out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)
assert out_row.shape == (10, 20)
assert out_row.dtype == torch.int8
assert out_row.device == A.device
assert row_stats.shape == (10,)
assert row_stats.dtype == torch.float32
assert row_stats.device == A.device
if threshold > 0.0:
assert outlier_cols is not None
assert outlier_cols.dim() == 1
assert outlier_cols.shape[0] <= A.shape[1]
assert outlier_cols.device == A.device
else:
assert outlier_cols is None
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_int8_mm_dequant(self, device):
A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
row_stats = torch.randn(256, dtype=torch.float32, device=device)
col_stats = torch.randn(256, dtype=torch.float32, device=device)
out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)
assert out.shape == A.shape
assert out.dtype == torch.float16
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE)
def test_int8_scaled_mm(self, device, dtype, has_bias):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
row_stats = torch.randn(10, dtype=torch.float32, device=device)
col_stats = torch.randn(30, dtype=torch.float32, device=device)
bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)
assert out.shape == (10, 30)
assert out.dtype == dtype
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))
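# A minimal sketch (not part of the test suite) of how the ops above compose into an
# int8 linear forward pass, with shapes as in these tests and the threshold left at 0:
#
#   A_q, row_stats, _ = torch.ops.bitsandbytes.int8_vectorwise_quant(A_fp16)  # (m, k) fp16 -> int8
#   B_q, col_stats, _ = torch.ops.bitsandbytes.int8_vectorwise_quant(B_fp16)  # (n, k) fp16 -> int8
#   out = torch.ops.bitsandbytes.int8_scaled_mm(A_q, B_q, row_stats, col_stats,
#                                               bias=None, dtype=torch.float16)  # (m, n) fp16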
class TestInt8BlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_blockwise(self, device, dtype, blocksize):
if device == "cpu" and dtype != torch.float32:
pytest.skip("CPU implementation is only available for float32")
code = bitsandbytes.functional.create_dynamic_map().to(device)
A = torch.randn(1024, 1024, dtype=dtype, device=device)
out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
assert out.shape == A.shape
assert out.dtype == torch.uint8
assert out.device == A.device
assert absmax.device == A.device
assert absmax.dtype == torch.float32
torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
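
    # dequantize_blockwise reverses the mapping, reconstructing a tensor of the
    # requested dtype from the uint8 codes and per-block absmax.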
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_blockwise(self, device, dtype, blocksize):
if device == "cpu" and dtype != torch.float32:
pytest.skip("CPU implementation is only available for float32")
A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)
n = A.numel()
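        # Ceiling division: one absmax value per block of `blocksize` elements.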
blocks = -(n // -blocksize)
absmax = torch.randn((blocks,), device=device, dtype=torch.float32)
out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype)
assert out.shape == A.shape
assert out.dtype == dtype
assert out.device == A.device
torch.library.opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
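

# 4-bit (fp4/nf4) blockwise quantization ops, with packed storage in a choice
# of storage dtype.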
class Test4bitBlockwiseQuantOps:
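    # quantize_4bit packs the input into the requested storage dtype and
    # returns per-block float32 absmax values.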
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.skip("CPU implementation is only available for nf4")
A = torch.randn(1024, 1024, dtype=dtype, device=device)
out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)
assert out.device == A.device
assert out.dtype == storage_dtype
assert absmax.device == A.device
assert absmax.dtype == torch.float32
torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
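
    # dequantize_4bit reconstructs a tensor of the given shape and dtype from
    # the packed 4-bit data and per-block absmax.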
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
shape = (128, 128)
n = prod(shape)
blocks = -(n // -blocksize)
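        # 4-bit packing stores two values per byte; the flat packed buffer is
        # viewed as storage_dtype and reshaped into a single column.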
quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)
A = (
torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device)
.view(storage_dtype)
.reshape(quantized_shape)
.contiguous()
)
absmax = torch.randn((blocks,), dtype=torch.float32, device=device)
out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
assert out.device == A.device
assert out.shape == shape
torch.library.opcheck(
torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
)
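
    # gemv_4bit multiplies a single input vector by the 4-bit quantized weight
    # matrix, dequantizing on the fly using the absmax values and code table.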
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
out_features = 1024
in_features = 256
A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)
out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
assert out.device == A.device
assert out.dtype == dtype
assert out.shape == (1, 1, out_features)
assert out.isreal().all()
torch.library.opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))