Unverified commit e82f72b3, authored by Matthew Douglas, committed by GitHub
Browse files

PyTorch Custom Operator Integration (#1544)



* Sketch out first custom op registration

* Add note

* Initial int8 op registration

* Cleanup some deprecated functions.

* Int8 ops updates; tests

* Implement 4bit quant/dequant ops

* Fix nested quant

* cleanup

* Test improvements

* Clean up and improve tests

* Add higher level custom op for int8 matmul + dequant + bias

* Add gemv 4bit custom op

* Cleanup

* Implement out kwarg overloads for custom ops

* Update PyTorch minimum to 2.1

* Deprecation updates

* Deprecation updates

* Cleanup; rename int8_linear_dequant -> int8_scaled_mm

* Bump min pytorch to 2.2

* cleanup

* Test reorganization

* Remove deprecated supports_igemmlt

* More cleanup

* Cleanup obsolete C++/CUDA code

* Cleanup

* Create 'default' backend for fallback op implementations; initial CPU nf4 work

* Stub out for multi-platform

* Fix serialization tests for torch>=2.6.0

* Add example for torch.compile e2e inference

* Test update

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent f0735f95
import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# End-to-end example: load a causal LM with 8-bit bitsandbytes quantization,
# wrap it with torch.compile, and generate a short completion.

# Optional TorchDynamo knob, left disabled:
# torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")

bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Optional TorchDynamo knob, left disabled:
# torch._dynamo.config.capture_dynamic_output_shape_ops = True

checkpoint = "google/gemma-2-2b-it"
# checkpoint = "Qwen/Qwen2.5-7B"

tok = AutoTokenizer.from_pretrained(checkpoint)
llm = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

prompt = "Write me a poem about Machine Learning."
# The tokenizer output is moved to the model's device before generation.
encoded = tok(prompt, return_tensors="pt").to(llm.device)

# Alternative: compile only the forward pass as a single graph.
# llm.forward = torch.compile(llm.forward, fullgraph=True)
compiled_llm = torch.compile(llm)

generated = compiled_llm.generate(**encoded, max_new_tokens=32)
print(tok.decode(generated[0]))
......@@ -42,7 +42,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
"torch>=2.0,<3",
"torch>=2.2,<3",
"numpy>=1.17"
]
......
......@@ -22,7 +22,7 @@ def torch_save_to_buffer(obj):
def torch_load_from_buffer(buffer):
    """Deserialize an object from a seekable in-memory buffer.

    The buffer is rewound before reading and again afterwards, so the same
    buffer can be loaded repeatedly.

    Args:
        buffer: A seekable binary file-like object previously written with
            ``torch.save``.

    Returns:
        The deserialized object.
    """
    buffer.seek(0)
    # weights_only=False is passed explicitly so arbitrary pickled objects
    # (not just plain tensors) round-trip. This unpickles arbitrary code:
    # only use it on buffers produced by trusted code, such as the tests
    # themselves.
    obj = torch.load(buffer, weights_only=False)
    buffer.seek(0)
    return obj
......@@ -36,6 +36,8 @@ def format_with_label(label: str, value: Any) -> str:
formatted = "T" if value else "F"
elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
formatted = "".join("T" if b else "F" for b in value)
elif isinstance(value, torch.dtype):
formatted = describe_dtype(value)
else:
formatted = str(value)
return f"{label}={formatted}"
......
from typing import Tuple
import pytest
import torch
import bitsandbytes as bnb
from tests.helpers import (
BOOLEAN_TRIPLES,
BOOLEAN_TUPLES,
TRUE_FALSE,
describe_dtype,
get_test_dims,
......@@ -16,189 +13,6 @@ from tests.helpers import (
TRANSPOSE_VALS = [(False, True), (False, False)]
# test_matmul: compares the op in funcs[1] (bnb.bmm_cublas / bnb.matmul_cublas)
# against the torch reference in funcs[0], for forward outputs and for gradients.
# Three input layouts are exercised, selected by which torch function funcs[0] is:
#   1. 2-D x 2-D (with optional transposes),
#   2. batched 3-D x 3-D,
#   3. 3-D x 2-D (torch.matmul only).
# NOTE(review): indentation was lost in this view of the file, so block nesting
# is not visible here; the statements below are reproduced exactly as found.
# NOTE(review): the `dtype` parameter is never referenced in the body below.
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize(
"funcs",
[(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
# Round the matrix dimensions down to multiples of 16.
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(25):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
# Operand shapes are swapped when the corresponding transpose flag is set,
# so that the .t() applied below yields (dim2, dim3) @ (dim3, dim4).
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
elif not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
elif transpose[0] and not transpose[1]:
out_torch = funcs[0](A.t(), B)
out_bnb = funcs[1](A.t(), B)
elif transpose[0] and transpose[1]:
out_torch = funcs[0](A.t(), B.t())
out_bnb = funcs[1](A.t(), B.t())
# Forward check: bound the fraction of elements outside tolerance
# (under 1.75% at the tight tolerance, under 0.1% at the loose one).
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
# Overwrite the bnb output values with the torch output so both losses
# start from identical forward values before backward is run.
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
# Clear accumulated grads so the reference backward starts from scratch.
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
# Gradient of B: bounded mismatch fractions, then a loose strict check.
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
B = torch.randn(
size=(dim1, dim3, dim4),
device="cuda",
requires_grad=req_grad[1],
)
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad):
# Same gradient-comparison procedure as in the 2-D case above.
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
# 3-D A times 2-D B (torch.matmul only)
if funcs[0] in [torch.matmul]:
# dim1 is rounded down here, so later iterations of the loop reuse the
# rounded value.
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
size=(dim1, dim2, dim3),
device="cuda",
requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim1, dim2, dim4),
device="cuda",
requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
else:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
# Same gradient-comparison procedure as in the 2-D case above.
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
......
import numpy as np
import pytest
from scipy.stats import norm
import torch
from bitsandbytes import functional as F
@pytest.mark.deprecated
def test_kbit_quantile_estimation():
    """Compare F.estimate_quantiles on Gaussian data with the analytic normal quantiles.

    Two configurations are checked, each on 100 fresh random samples:

    * ``2**bits`` quantiles for ``bits`` in 2..8 with ``offset=0``, against
      ``norm.ppf`` evaluated on an evenly spaced grid over
      ``[1.3e-4, 1 - 1.3e-4]``.
    * ``2**bits - 1`` quantiles for ``bits`` in 2..3 with the default offset,
      against ``norm.ppf`` evaluated at the midpoints of equal-width bins.

    The mean absolute error must stay below a fixed threshold in each case.
    """
    for _ in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for _ in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 4):
            total_values = 2**bits - 1
            # Probabilities at the midpoints of `total_values` equal-width bins
            # on [0, 1]. An earlier, immediately overwritten computation of
            # `p` via index selection was dead code and has been removed.
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_estimate_quantiles(dtype):
    """Check F.estimate_quantiles on uniform and on Gaussian inputs.

    For uniform data the result is compared with an evenly spaced grid of 256
    values from 1/512 to 511/512. For Gaussian data it is compared with
    ``torch.quantile`` evaluated at that same grid, and no entry may differ
    by more than 5e-2.
    """
    uniform = torch.rand(1024, 1024, device="cuda").to(dtype)
    estimated = F.estimate_quantiles(uniform)
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=uniform.device)
    torch.testing.assert_close(percs, estimated, atol=1e-3, rtol=1e-2)

    gaussian = torch.randn(1024, 1024, device="cuda").to(dtype)
    estimated = F.estimate_quantiles(gaussian)
    # torch.quantile is evaluated on a float32 copy of the sample.
    reference = torch.quantile(gaussian.float(), percs)
    abs_err = torch.abs(estimated - reference)
    num_outliers = (abs_err > 5e-02).sum().item()
    assert num_outliers == 0
# test_quantile_quantization: round-trips tensors through
# F.quantize_no_absmax / F.dequantize_no_absmax using a code produced by
# F.estimate_quantiles, and bounds the mean absolute reconstruction error.
# NOTE(review): indentation was lost in this view of the file; how much of the
# body belongs to the `for` loop cannot be confirmed from here.
@pytest.mark.deprecated
def test_quantile_quantization():
for i in range(100):
# Gaussian input: mean absolute error must stay below 0.0075.
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
# Uniform input: every element within 5e-3 absolute, mean error below 0.001.
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
# test_dynamic_quantization: round-trips tensors through F.quantize /
# F.dequantize and bounds the reconstruction error, first for Gaussian and
# then for uniform inputs.
# NOTE(review): indentation was lost in this view of the file; whether the
# first assert and the prints sit inside or after the first loop cannot be
# confirmed from here.
@pytest.mark.deprecated
def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2)
# Relative error; the 1e-8 is added to A1 before taking the absolute value.
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
# Report the average absolute and relative errors over the collected runs.
print(sum(diffs) / len(diffs))
print(sum(reldiffs) / len(reldiffs))
for i in range(100):
# Uniform input: every element within 1e-2 absolute, mean error below 0.004.
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
    """Check F.percentile_clipping against a reference norm history kept in the test.

    For 20 steps a random 4x4 gradient is passed to ``F.percentile_clipping``
    together with a 100-entry history buffer (``gnorm_vec2``). The test keeps
    its own history of gradient norms (``gnorm_vec1``) and verifies that:

    * the returned scale is 1.0 when the current norm is below the clip value,
      and ``clip / norm`` otherwise;
    * the square root of the library's history buffer matches the reference
      history;
    * the returned clip value equals the ``percentile``-th smallest reference
      norm;
    * the returned current norm equals ``torch.norm`` of the gradient.
    """
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for _ in range(20):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        # The expected value is computed first: writing it inline as
        # `assert scale == 1.0 if cond else ratio` parses as
        # `assert ((scale == 1.0) if cond else ratio)`, which only asserts
        # that the ratio is truthy in the else branch.
        expected_scale = 1.0 if gnorm1 < clip2 else clip2 / gnorm1
        assert gnorm_scale == expected_scale

        gnorm2 = torch.norm(g.float())
        if step == 1:
            # On the first step the whole reference history is filled with the norm.
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, _ = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)
from itertools import product
import math
import random
import time
......@@ -6,7 +5,6 @@ import time
import einops
import numpy as np
import pytest
from scipy.stats import norm
import torch
import bitsandbytes as bnb
......@@ -88,37 +86,12 @@ class Timer:
print("Resetting benchmark data")
def setup():
    """Do nothing and return None."""
    return None
def teardown():
    """Do nothing and return None."""
    return None
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    """Check F.estimate_quantiles on uniform and on Gaussian inputs.

    Uniform data is compared with an evenly spaced grid of 256 values from
    1/512 to 511/512; Gaussian data is compared with ``torch.quantile`` at
    that grid, allowing no entry to differ by more than 5e-2.
    """
    sample = torch.rand(1024, 1024, device="cuda").to(dtype)
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=sample.device)
    torch.testing.assert_close(percs, F.estimate_quantiles(sample), atol=1e-3, rtol=1e-2)

    sample = torch.randn(1024, 1024, device="cuda").to(dtype)
    estimated = F.estimate_quantiles(sample)
    # torch.quantile is evaluated on a float32 copy of the sample.
    expected = torch.quantile(sample.float(), percs)
    deviation = torch.abs(estimated - expected)
    too_far = (deviation > 5e-02).sum().item()
    assert too_far == 0
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed):
diffs = []
reldiffs = []
for i in range(100):
......@@ -160,6 +133,148 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
# test_blockwise_cpu_large: round-trips a (128, 128, hidden) CPU tensor through
# F.quantize_blockwise / F.dequantize_blockwise for block sizes 4096 and 16384,
# prints the elapsed wall-clock time, and bounds the mean absolute error.
# NOTE(review): indentation was lost in this view of the file; the loop level
# at which the final assert runs cannot be confirmed from here.
def test_blockwise_cpu_large(self):
diffs = []
reldiffs = []
batch = 128
seq = 128
for hidden in [128]: # , 14336]:
for blocksize in [4096, 16384]:
for i in range(2):
A1 = torch.randn(batch, seq, hidden, device="cpu")
t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
# Elapsed time for one quantize + dequantize round trip.
print(time.time() - t0)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
# reldiffs is only consumed by the commented-out print below.
reldiffs.append(reldiff.mean().item())
# The most recent mean absolute error must stay below 0.011.
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
# test_few_bit_quant: builds a quantization code with one of four methods
# ("linear", "fp8", "dynamic", "quantile") for 2..8 bits, checks the code's
# size and number of distinct values, then quantizes small random inputs with
# F.quantize_blockwise and compares the indices with a brute-force
# nearest-code-entry search.
# NOTE(review): indentation was lost in this view of the file, so block nesting
# is not visible here; the statements are reproduced exactly as found.
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
def test_few_bit_quant(self, bits, method):
# abserrs / relerrs are appended to below but never asserted on.
abserrs = []
relerrs = []
code = None
if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == "fp8":
# Split the bit budget: ebits and pbits sum to bits - 1.
ebits = math.ceil(bits / 2)
pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).cuda()
elif method == "quantile":
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
# print(method, (code==0).sum())
# The code tensor always has 256 entries regardless of `bits`.
assert code.numel() == 256
for i in range(10):
# 32 random values scaled so the largest magnitude is 1.
values = torch.randn(1, 32, device="cuda")
values /= values.abs().max()
# values[values.abs() < 1e-6] += 1e-5
# Reference: for each value, the index and value of the closest code entry.
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v - code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2 - values)
abserrs.append(err2.mean().item())
relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
# NOTE(review): the strict comparison below runs only when idx.sum() is
# zero, i.e. when no index matched; when at least one matches, nothing
# is asserted (err1 is computed but unused). Confirm this is intended.
if idx.sum():
# some weird cases
err1 = torch.abs(v1 - values).mean()
# assert err2.mean() <= err1
else:
torch.testing.assert_close(q1, q2)
# test_fp8_quant: for each exponent/precision split with e_bits + p_bits == 7,
# builds a code with F.create_fp8_map and round-trips random tensors through
# F.quantize_blockwise / F.dequantize_blockwise, collecting absolute and
# relative errors. A third loop does the same with the default code.
# All assertions and prints are commented out, so the test only verifies that
# these calls run without raising.
# NOTE(review): indentation was lost in this view of the file; which loops are
# nested inside the e_bits loop cannot be confirmed from here.
def test_fp8_quant(self):
for e_bits in range(1, 7):
p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
abserr = []
relerr = []
for i in range(100):
# Gaussian input with the fp8 code.
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
# Uniform input with the fp8 code.
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
# Gaussian input with no explicit code passed to quantize_blockwise.
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
# assert diff < 0.0075
# print(3, sum(abserr)/len(abserr))
# print(3, sum(relerr)/len(relerr))
# test_bench_dequantization: benchmark-marked test that quantizes a 1024x1024
# half-precision tensor once with an fp8 code and then 100 more times with the
# default code, synchronizing CUDA before and after the loop.
# NOTE(review): despite the name, the loop calls F.quantize_blockwise, not a
# dequantize function; max_theoretical_mu and t0 are only read by the
# commented-out prints. There are no assertions.
# NOTE(review): indentation was lost in this view of the file.
@pytest.mark.benchmark
def test_bench_dequantization(self):
a = torch.rand(1024, 1024, device="cuda").half()
code = F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max())
max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
# print(max_theoretical_mu)
# Drain pending CUDA work so it is not attributed to the loop below.
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize()
# print((time.time()-t0)/1e6)
def test_stable_embedding():
    """Smoke test: build a 1024x1024 StableEmbedding and reset its parameters.

    There are no assertions; the test passes if neither call raises.
    """
    embedding = bnb.nn.StableEmbedding(1024, 1024)
    embedding.reset_parameters()
def quant(x):
max1 = torch.abs(x).max()
......@@ -198,11 +313,6 @@ def quant_multi_chunk(x, dim, chunk_size=32):
return max1, x.to(torch.int8)
def quant_minmax(A):
minA = A.min()
maxA = A.max()
def mean(xx):
return sum(xx) / float(len(xx))
......@@ -219,11 +329,12 @@ methods = {
}
@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(dim1, dim2, quant_methods, batched):
class TestIGEMMFunctional:
@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(self, dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
......@@ -258,23 +369,19 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
# print(mean(errors))
# print(mean(relerrors))
def test_stable_embedding():
    """Smoke test: construct a StableEmbedding(1024, 1024) and call reset_parameters().

    Nothing is asserted; the test only checks that both calls complete.
    """
    stable_emb = bnb.nn.StableEmbedding(1024, 1024)
    stable_emb.reset_parameters()
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
shapeB = (
(32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
......@@ -294,7 +401,9 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
shapeB = (
(32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
......@@ -306,11 +415,10 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
torch.testing.assert_close(out.float(), out2)
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
......@@ -323,12 +431,11 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
torch.testing.assert_close(out.float(), out2)
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
......@@ -394,13 +501,12 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
assert mean(errs) < 0.015
assert mean(relerrs) < 0.3
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(self, dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
......@@ -425,13 +531,14 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
torch.testing.assert_close(out.float(), out2.float())
@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb):
class TestLLMInt8Functional:
@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
......@@ -443,13 +550,12 @@ def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb):
C2 = F.int8_linear_matmul(A, B)
torch.testing.assert_close(C1, C2.float())
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims):
for i in range(k):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
......@@ -467,12 +573,11 @@ def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):
torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(dim1, dim4, dims, has_bias):
@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(self, dim1, dim4, dims, has_bias):
inner = 128
bias = None
if has_bias:
......@@ -509,12 +614,11 @@ def test_dequant_mm(dim1, dim4, dims, has_bias):
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
def test_colrow_absmax(dim1, dim2, dims, threshold):
@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
def test_colrow_absmax(self, dim1, dim2, dims, threshold):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
......@@ -548,10 +652,9 @@ def test_colrow_absmax(dim1, dim2, dims, threshold):
torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_close(row_stats1, row_stats2)
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(dim1, dim2):
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(self, dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
......@@ -570,17 +673,20 @@ def test_int8_double_quant(dim1, dim2):
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
print(
f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
)
assert False
if num_not_close_rows > (min_error * n):
print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
print(
f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
)
assert False
torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_close(Scol.flatten().float(), statsAt)
@pytest.mark.parametrize(
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
......@@ -590,8 +696,8 @@ def test_int8_double_quant(dim1, dim2):
(4, 256, 512, 4096),
)
),
)
def test_integrated_int8_linear_matmul(dim1, dim4, inner):
)
def test_integrated_int8_linear_matmul(self, dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
......@@ -618,86 +724,9 @@ def test_integrated_int8_linear_matmul(dim1, dim4, inner):
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.025
# test_igemmlt_row_scale: unconditionally skipped test (see the skip reason
# below) that compares three int8 matmul paths against a half-precision
# torch.matmul reference and prints the mean absolute error of each.
# It contains no active assertions.
# NOTE(review): `k` (the iteration count) is not defined in this block; it is
# presumably a module-level constant. Confirm against the full file.
# NOTE(review): indentation was lost in this view of the file.
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner) in zip(
get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6),
)
),
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
# relerr1 / relerr2 are never appended to or read.
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# Reference result in half precision.
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
# Path 3: int8 output with a uniform row scale of 1 / c.
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
# C3, S = F.nvidia_transform(outC32, "row", state=SC)
C3 = outC32
# Adapt `scale` for the next iteration from the largest int8 output magnitude.
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
scale = maxval / 120
out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
# Path 2: int8 matmul followed by F.int8_mm_dequant.
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32 = F.int8_linear_matmul(A2, B2)
out2 = F.int8_mm_dequant(outC32, stats1a, stats2a)
# Path 4: vectorwise-quantized operands multiplied in float and rescaled.
CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
out4 = C * SA * SB / (127 * 127)
# out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
# print('='*80)
# print(out1)
# print(out2)
# print(out3)
# print(out1)
# print(out2)
# print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
# assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
# Report the mean error of each path against the reference.
print("")
print(sum(err1) / len(err1))
print(sum(err2) / len(err2))
print(sum(err3) / len(err3))
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(dim1, dim2):
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(self, dim1, dim2):
threshold = 2.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
......@@ -714,10 +743,9 @@ def test_coo_double_quant(dim1, dim2):
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(dim1, dim2):
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(self, dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
......@@ -731,10 +759,11 @@ def test_coo_int8_vectorwise_quant(dim1, dim2):
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(dim1, dim2, transposed_B):
class TestSpMMFunctional:
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(self, dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
# dim3 = 17
......@@ -761,9 +790,8 @@ def test_spmm_coo(dim1, dim2, transposed_B):
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
@pytest.mark.benchmark
def test_spmm_bench():
@pytest.mark.benchmark
def test_spmm_bench(self):
batch = 2
model = 1024 * 1
hidden = model * 4
......@@ -803,59 +831,11 @@ def test_spmm_bench():
print(tsp, t8)
print(tsp / t8)
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
def test_integrated_sparse_decomp(dim1, dim2):
    """Check that adding the sparse outlier product beats plain int8 matmul.

    The dense int8 path quantizes everything; the decomposed path quantizes
    with an outlier threshold and adds back ``spmm_coo`` of the outliers.
    The decomposed result must have a strictly smaller mean absolute error.
    """
    outlier_threshold = 3.0
    for _ in range(k):
        activations = torch.randn(dim1, dim2).cuda().half()
        weights = torch.randn(dim1, dim2).cuda().half()
        reference = torch.matmul(activations, weights.t())

        weights_q, weights_stats, _ = F.int8_vectorwise_quant(weights)
        acts_q, acts_stats, _ = F.int8_vectorwise_quant(activations)

        # Dense path: no outlier handling at all.
        dense_acc = F.int8_linear_matmul(acts_q, weights_q)
        dense_out = F.int8_mm_dequant(dense_acc, acts_stats, weights_stats)

        # Decomposed path: re-quantize the activations with a threshold.
        acts_q, _, acts_stats, _, outliers = F.double_quant(activations, threshold=outlier_threshold)
        inlier_acc = F.int8_linear_matmul(acts_q, weights_q)
        inlier_out = F.int8_mm_dequant(inlier_acc, acts_stats, weights_stats)

        assert outliers is not None

        outlier_out = F.spmm_coo(outliers, weights.t())
        decomposed_out = inlier_out + outlier_out

        dense_err = torch.abs(reference - dense_out).mean().item()
        decomposed_err = torch.abs(reference - decomposed_out).mean().item()
        assert decomposed_err < dense_err
def test_matmuls():
    """Check bnb.matmul and bnb.matmul_cublas against torch.matmul (mean abs error < 0.2)."""
    lhs = torch.randn(256, 512).half().cuda()
    rhs = torch.randn(256, 512).half().cuda()
    expected = torch.matmul(lhs, rhs.t())
    candidates = (
        bnb.matmul(lhs, rhs),
        bnb.matmul_cublas(lhs, rhs.t()),
    )
    # Both results are computed before any assertion runs.
    errors = [torch.abs(expected - candidate).mean().item() for candidate in candidates]
    for error in errors:
        assert error < 0.2
    print(errors[0], errors[1])
@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
......@@ -911,48 +891,43 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
# torch.cuda.synchronize()
# print(time.time() - t0)
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
# Use the declarative marker: `pytest.skip(...)` is an imperative call that
# raises as soon as the decorator expression is evaluated, i.e. while the
# module is being imported/collected, instead of skipping just this test.
@pytest.mark.skip("No longer supported")
def test_integrated_sparse_decomp(self, dim1, dim2):
    threshold = 3.0
    for _ in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())  # fp16 reference product
def test_coo2csr():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
CA, statsA, _ = F.int8_vectorwise_quant(A)
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_close(A2[idx], csrA.values)
out1_32 = F.int8_linear_matmul(CA, Cw1)
out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)
# CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
def test_coo2csc():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
out1_32 = F.int8_linear_matmul(CA, Cw1)
out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0
torch.testing.assert_close(A2.t()[idx], cscA.values)
assert coo_tensor is not None
out4 = F.spmm_coo(coo_tensor, w1.t())
# idx = torch.unique(coo_tensor._indices()[1]).long()
# out4 = torch.matmul(A, w1.t())
out5 = out3 + out4
@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(dim1, dim2, dtype):
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(self, dim1, dim2, dtype):
threshold = 6.0
# threshold = 2.8
# threshold = 0.0
......@@ -1050,325 +1025,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
print("partial matmul", time.time() - t0)
def test_zeropoint():
    """Print the error of several zero-point / scale matmul formulations.

    Each ``C*`` is a different way of computing ``A @ B`` (shifted operands
    with a correction term, ``bnb.matmul``, or ``quant_zp``-scaled operands)
    and is compared against the fp32 product ``C1``. The six mean absolute
    errors are printed; the test makes no assertions.
    """

    def quant_zp(x):
        # Affine-map x so that its value range spans 254 steps, returning the
        # mapped tensor together with the scale `qx` and zero point `zpx`.
        # The result is not rounded or cast to an integer type here.
        dtype = x.dtype  # unused
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1  # avoid dividing by zero for constant input
        qx = 254.0 / dyna
        minx = x.min()  # unused
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1
    C0 = torch.matmul(A, B)  # fp16 product; only referenced by the commented prints
    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()
    C1 = torch.matmul(A, B)  # fp32 reference for all error terms below
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())
    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    # Shift B by zp, then subtract the row sums of A times zp.
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp
    ca, cqa, cza = quant_zp(A)  # only referenced by the commented prints
    # print(ca.min(), ca.max())
    # print((ca - cza).min(), (ca - cza).max())
    zp = 1
    scale = 2.0
    # Scale and shift A, then add the column sums of B times zp and unscale.
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale
    # Same correction shape with the data-derived scale / zero point for A.
    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa
    # Both operands scaled and shifted, with fixed constants.
    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb
    # Both operands scaled and shifted, with data-derived constants.
    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb
    # print("")
    # print(C0.flatten()[:10])
    # print(C1.flatten()[:10])
    # print(C2.flatten()[:10])
    # print(C3.flatten()[:10])
    # print(C5.flatten()[:10])
    # print(C6.flatten()[:10])
    # print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)
@pytest.mark.deprecated
def test_extract_outliers():
    """Check F.extract_outliers returns the selected columns for both tiled layouts."""
    for _ in range(k):
        shape = (4096, 4096 * 4)
        # Up to 10 distinct, sorted column indices.
        columns = torch.unique(torch.randint(0, shape[1], size=(10,)).int()).cuda()
        A = torch.randint(-128, 127, size=shape, device="cuda").to(torch.int8)
        expected = A[:, columns.long()]

        for layout in ("col_turing", "col_ampere"):
            CA, SA = F.transform(A, layout)
            extracted = F.extract_outliers(CA, SA, columns)

            assert extracted.shape[0] == shape[0]
            assert extracted.shape[1] == columns.numel()

            torch.testing.assert_close(expected, extracted)
def test_blockwise_cpu_large():
    """Round-trip CPU tensors through blockwise quantization with large block sizes.

    Prints the elapsed time of every quantize/dequantize pair and requires the
    mean absolute reconstruction error to stay below 0.011.
    """
    abs_errors, rel_errors = [], []
    batch, seq = 128, 128
    for hidden in (128,):
        for blocksize in (4096, 16384):
            for _ in range(2):
                original = torch.randn(batch, seq, hidden, device="cpu")
                started = time.time()
                quantized, state = F.quantize_blockwise(original, blocksize=blocksize)
                restored = F.dequantize_blockwise(quantized, state, blocksize=blocksize)
                print(time.time() - started)
                delta = torch.abs(original - restored)
                relative = delta / torch.abs(original + 1e-8)
                abs_errors.append(delta.mean().item())
                rel_errors.append(relative.mean().item())
                assert abs_errors[-1] < 0.011
def test_fp8_quant():
    """Smoke-test blockwise quantization with FP8 code maps of every exponent width.

    For each exponent/precision split, three batches of 100 round trips are
    run: normal data with the FP8 map, uniform data with the FP8 map, and
    normal data with the default map. Errors are collected but not asserted.
    """

    def collect_roundtrip_errors(sampler, **quant_kwargs):
        # Run 100 quantize/dequantize round trips on fresh samples and gather
        # mean absolute and relative errors.
        abs_errors, rel_errors = [], []
        for _ in range(100):
            original = sampler(1024, 1024, device="cuda")
            quantized, state = F.quantize_blockwise(original, **quant_kwargs)
            restored = F.dequantize_blockwise(quantized, state)
            delta = torch.abs(original - restored)
            relative = delta / torch.abs(original + 1e-8)
            abs_errors.append(delta.mean().item())
            rel_errors.append(relative.mean().item())
        return abs_errors, rel_errors

    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        collect_roundtrip_errors(torch.randn, code=code)
        collect_roundtrip_errors(torch.rand, code=code)
        collect_roundtrip_errors(torch.randn)
def test_few_bit_quant():
    """Check 2- to 8-bit code maps and compare blockwise quantization to a nearest-code search.

    For every bit width and map type ("linear", "fp8", "dynamic", "quantile")
    the test validates the size of the generated code map, then quantizes
    small random vectors both by brute-force nearest-code lookup and with
    ``F.quantize_blockwise`` and records the reconstruction errors.
    """
    # print('')
    for bits in range(2, 9):
        # print('='*30, bits, '='*30)
        for method in ["linear", "fp8", "dynamic", "quantile"]:
            abserrs = []
            relerrs = []
            code = None
            if method == "linear":
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == "fp8":
                # Split the bits into exponent and precision, leaving one bit over.
                ebits = math.ceil(bits / 2)
                pbits = bits - ebits - 1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == "dynamic":
                code = F.create_dynamic_map(True, bits - 0, bits).cuda()
            elif method == "quantile":
                values = torch.randn(2048, 2048, device="cuda")
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
            # print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):
                values = torch.randn(1, 32, device="cuda")
                values /= values.abs().max()  # normalize into [-1, 1]
                # values[values.abs() < 1e-6] += 1e-5
                # Reference: pick the closest code entry for each value.
                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v - code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())
                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()
                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)
                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2 - values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
                # NOTE(review): the strict comparison below only runs when no
                # index matches at all; if any index matches, nothing is
                # asserted — confirm this is the intended condition.
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1 - values).mean()
                    # assert err2.mean() <= err1
                else:
                    torch.testing.assert_close(q1, q2)
            # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    # assert False
def test_kbit_quantile_estimation():
    """Compare F.estimate_quantiles on normal data with the exact normal quantiles.

    The first pass checks 2**bits quantiles (bits 2..8) with ``offset=0``
    against ``norm.ppf`` on a grid clipped at 1.3e-4 from both ends. The
    second pass checks ``2**bits - 1`` quantiles (bits 2..3) with the default
    offset against ``norm.ppf`` evaluated at the bin midpoints.
    """
    for _ in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for _ in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 4):
            total_values = 2**bits - 1
            # Midpoints of `total_values` equal-width probability bins. The
            # original computed this grid a second way first and immediately
            # overwrote it; that dead code is dropped.
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=total_values)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035
@pytest.mark.benchmark
def test_bench_dequantization():
a = torch.rand(1024, 1024, device="cuda").half()
code = F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max())
class TestSparseTensorFunctional:
def test_coo2csr(self):
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
# print(max_theoretical_mu)
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_close(A2[idx], csrA.values)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize()
# print((time.time()-t0)/1e6)
def test_coo2csc(self):
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0
torch.testing.assert_close(A2.t()[idx], cscA.values)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(dtype, quant_type, blocksize):
vals = list(product([0, 1], repeat=4))
code = {}
for bits in vals:
result = 0
bias = 3
sign, e1, e2, p1 = bits
idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
sign = -1.0 if sign else 1.0
exp = e1 * 2 + e2 * 1
if exp == 0:
# sub-normal
if p1 == 0:
result = 0
else:
result = sign * 0.0625
else:
# normal
exp = 2 ** (-exp + bias + 1)
frac = 1.5 if p1 else 1.0
result = sign * exp * frac
code[idx] = result
class TestQuantize4BitFunctional:
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(self, dtype, quant_type, blocksize):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err / (A1.abs().float() + 1e-8)).mean()
idx = err > 1.0
err = err.mean()
assert A2.dtype == dtype
......@@ -1391,9 +1095,8 @@ def test_4bit_quant(dtype, quant_type, blocksize):
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert err.item() < math.log2(blocksize) * 8e-2
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(self, quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
......@@ -1425,11 +1128,10 @@ def test_4bit_compressed_stats(quant_type):
# print(sum(errs1)/len(errs1), blocksize, quant_type)
# print(sum(errs2)/len(errs2), blocksize, quant_type)
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(self, quant_type):
blocksize = 256
a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
......@@ -1458,32 +1160,16 @@ def test_bench_4bit_dequant(quant_type):
# torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
def test_normal_map_tree():
    """Walk the pivot levels of a binary search tree over the 16 normal-map values.

    Only exercises the computation; the pivots are not asserted against anything.
    """
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    # print(values)
    for num_pivots in (1, 2, 4, 8):
        stride = 16 // num_pivots
        positions = list(range(stride // 2, 16, stride))
        # print(positions)
        # Each pivot is the midpoint between two neighbouring code values.
        pivots = [(values[pos - 1] + values[pos]) / 2 for pos in positions]
        # print(pivots)
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
"quant_storage",
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
)
def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]:
# for dim in [4*1024]:
# for dim in [1*16]:
......@@ -1613,6 +1299,46 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
# Derive the id from the value (as test_gemv_4bit does); the previous
# hard-coded id "DQ_True" mislabelled the only parameter value, False.
@pytest.mark.parametrize("double_quant", [False], ids=lambda double_quant: f"DQ_{double_quant}")
def test_gemv_eye_4bit(self, storage_type, dtype, double_quant):
    """Multiplying by a 4-bit quantized identity matrix must return the input.

    Runs ``bnb.matmul_4bit`` both without and with ``requires_grad`` set on
    the input and compares each result, plus the plain torch product, to the
    input itself.
    """
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=dims)
    # Bump every size up to the next multiple of 64 (a multiple of 64 moves
    # to the following one, so no size ends up as 0).
    dims = [dim + (64 - (dim % 64)) for dim in dims]
    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
        B = torch.eye(dim, dtype=dtype, device="cuda")

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        # Repeat the same product with gradients enabled on the input.
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
def test_normal_map_tree():
    """Walk the pivot levels of a binary search tree over the 16 normal-map values.

    Only exercises the computation; the pivots are not asserted against anything.
    """
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    # print(values)
    for num_pivots in (1, 2, 4, 8):
        stride = 16 // num_pivots
        positions = list(range(stride // 2, 16, stride))
        # print(positions)
        # Each pivot is the midpoint between two neighbouring code values.
        pivots = [(values[pos - 1] + values[pos]) / 2 for pos in positions]
        # print(pivots)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
......@@ -1637,32 +1363,6 @@ def test_managed():
assert (A == 17 * (2**3)).sum().item() == n * n
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
# Derive the id from the value; the previous hard-coded id "DQ_True"
# mislabelled the only parameter value, False.
@pytest.mark.parametrize("double_quant", [False], ids=lambda double_quant: f"DQ_{double_quant}")
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    """Multiplying by a 4-bit quantized identity matrix must return the input.

    Runs ``bnb.matmul_4bit`` both without and with ``requires_grad`` set on
    the input and compares each result, plus the plain torch product, to the
    input itself.
    """
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=dims)
    # Bump every size up to the next multiple of 64 (a multiple of 64 moves
    # to the following one, so no size ends up as 0).
    dims = [dim + (64 - (dim % 64)) for dim in dims]
    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
        B = torch.eye(dim, dtype=dtype, device="cuda")

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        # Repeat the same product with gradients enabled on the input.
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
......@@ -1676,169 +1376,3 @@ def test_vector_quant(dim1, dim2, dim3):
A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
@pytest.mark.deprecated
def test_quantile_quantization():
    """Round-trip normal and uniform data through quantile-based quantization.

    Normal samples must reconstruct with mean absolute error below 0.0075;
    uniform samples must be element-wise within 5e-3 and below 0.001 on average.
    """

    def roundtrip(sample):
        # Quantize with quantiles estimated from the sample itself.
        quantiles = F.estimate_quantiles(sample)
        packed = F.quantize_no_absmax(sample, quantiles)
        restored = F.dequantize_no_absmax(packed, quantiles)
        return restored, torch.abs(sample - restored).mean().item()

    for _ in range(100):
        normal_sample = torch.randn(1024, 1024, device="cuda")
        _, normal_err = roundtrip(normal_sample)
        assert normal_err < 0.0075

        uniform_sample = torch.rand(1024, 1024, device="cuda")
        uniform_restored, uniform_err = roundtrip(uniform_sample)
        torch.testing.assert_close(uniform_sample, uniform_restored, atol=5e-3, rtol=0)
        assert uniform_err < 0.001
@pytest.mark.deprecated
def test_dynamic_quantization():
    """Round-trip normal and uniform data through F.quantize / F.dequantize.

    Normal samples must stay below 0.0135 mean absolute error (the averages
    are printed afterwards); uniform samples must be element-wise within 1e-2
    and below 0.004 on average.
    """
    abs_means = []
    rel_means = []
    for _ in range(100):
        sample = torch.randn(1024, 1024, device="cuda")
        packed, state = F.quantize(sample)
        restored = F.dequantize(packed, state)
        delta = torch.abs(sample - restored)
        relative = delta / torch.abs(sample + 1e-8)
        abs_means.append(delta.mean().item())
        rel_means.append(relative.mean().item())
        assert delta.mean().item() < 0.0135
    print(sum(abs_means) / len(abs_means))
    print(sum(rel_means) / len(rel_means))

    for _ in range(100):
        sample = torch.rand(1024, 1024, device="cuda")
        packed, state = F.quantize(sample)
        restored = F.dequantize(packed, state)
        mean_err = torch.abs(sample - restored).mean().item()
        torch.testing.assert_close(sample, restored, atol=1e-2, rtol=0)
        assert mean_err < 0.004
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
    """Check F.percentile_clipping against a reference kept in plain torch.

    ``gnorm_vec2`` is the history buffer handed to the function under test;
    ``gnorm_vec1`` is the reference history of gradient norms maintained here.
    """
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for _ in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        # The conditional expression binds looser than `==`, so the original
        # `assert gnorm_scale == 1.0 if ... else clip2 / gnorm1` only asserted
        # that `clip2 / gnorm1` was truthy in the clipped case. Build the
        # expected value first, then compare.
        expected_scale = 1.0 if gnorm1 < clip2 else clip2 / gnorm1
        assert gnorm_scale == expected_scale

        gnorm2 = torch.norm(g.float())
        if step == 1:
            # The first step fills the whole reference history.
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, _ = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        # The library buffer is compared after a square root, i.e. it is
        # expected to hold the squares of the reference norms.
        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)
@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """Check F.transform against F.nvidia_transform for the same target layout.

    When ``transpose`` is set, the reference is produced by transposing the
    input first and transforming the contiguous result.
    """
    for _ in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        # Mark the last element so it is distinguishable from the random fill.
        A.view(-1)[-1] = -1

        reference_input = A.t().contiguous() if transpose else A
        expected, expected_state = F.nvidia_transform(reference_input, to_order=orderOut)
        actual, actual_state = F.transform(A, to_order=orderOut, transpose=transpose)

        for axis in (0, 1):
            assert expected_state[0][axis] == actual_state[0][axis]
        # print(expected)
        # print(actual)

        torch.testing.assert_close(expected, actual)
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
@pytest.mark.deprecated
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """Check the output of F.nvidia_transform for each target layout.

    "row" and "col" are compared element-wise with the (transposed) input,
    "col32" is checked by element count and by a round trip back to "row".
    3-D inputs and int32 inputs are only exercised for "col32"; other
    combinations return early and are reported as passed.
    """
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    try:
        # Only used to detect unsupported combinations; `func` is not called.
        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    except ValueError as ve:
        pytest.skip(str(ve))  # skip if not supported
    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
    out, S = F.nvidia_transform(A, to_order=orderOut)
    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        # NOTE(review): `x + (32 - x % 32)` adds a full extra 32 when x is
        # already a multiple of 32 — confirm this matches the buffer size
        # nvidia_transform actually allocates for such inputs.
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # Not reachable with the current `orderOut` parametrization above.
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col
                # The tile offsets below only feed the commented-out checks.
                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32
                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
    if orderOut == "col32":
        # Transform back to row order and require an exact round trip.
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_close(A, out2)
......@@ -8,8 +8,6 @@ import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
......@@ -18,28 +16,9 @@ from tests.helpers import (
torch_save_to_buffer,
)
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    """Check that undo_layout exactly inverts the tiled layouts.

    For each (tile size, layout) pair the int8 matrix is transformed, restored
    with the tile indices from ``get_inverse_transform_indices`` and compared
    with the original.
    """
    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
    layouts = (((8, 32), "col_turing"), ((32, 32), "col_ampere"))
    for tile_size, order in layouts:

        def transform(tensor):
            # Run the layout change on the GPU, then move the result back to
            # the device the argument came from.
            return F.transform(tensor.cuda(), from_order="row", to_order=order)[0].to(tensor.device)

        tile_indices = get_inverse_transform_indices(transform, tile_size)
        transformed = transform(x)

        torch.cuda.synchronize()
        restored = undo_layout(transformed, tile_indices)
        torch.cuda.synchronize()

        assert restored.is_contiguous()
        assert torch.all(torch.eq(restored, x))
def test_linear_no_igemmlt():
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
......@@ -139,7 +118,7 @@ def test_linear_serialization(
if not has_fp16_weights:
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
new_state_dict = torch.load(state_path_8bit)
new_state_dict = torch.load(state_path_8bit, weights_only=False)
new_linear_custom = Linear8bitLt(
linear.in_features,
......
from math import prod
import pytest
import torch
import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter
class TestLLMInt8Ops:
    """Shape/dtype/device checks plus ``torch.library.opcheck`` for the int8 custom ops."""

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    def test_int8_linear_matmul(self, device):
        """``int8_linear_matmul.default`` on (10, 20) and (30, 20) int8 inputs yields a (10, 30) int32 result."""
        op = torch.ops.bitsandbytes.int8_linear_matmul.default
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)

        result = op(A, B)

        assert result.shape == (10, 30)
        assert result.dtype == torch.int32
        assert result.device == A.device

        torch.library.opcheck(op, (A, B))

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    def test_int8_linear_matmul_out(self, device):
        """``int8_linear_matmul.out`` is called with a preallocated (10, 30) int32 buffer, which is then checked."""
        op = torch.ops.bitsandbytes.int8_linear_matmul.out
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
        buffer = torch.empty((10, 30), dtype=torch.int32, device=device)

        op(A, B, buffer)

        assert buffer.shape == (10, 30)
        assert buffer.dtype == torch.int32
        assert buffer.device == A.device

        torch.library.opcheck(op, (A, B, buffer))

    @pytest.mark.parametrize("threshold", [0.0, 6.0])
    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    def test_int8_vectorwise_quant(self, threshold, device):
        """``int8_vectorwise_quant`` returns int8 rows, float32 row stats and optional outlier columns."""
        if device == "cpu":
            pytest.skip("CPU implementation is not available")

        op = torch.ops.bitsandbytes.int8_vectorwise_quant
        A = torch.randn(10, 20, dtype=torch.float16, device=device)
        # Plant one very large value in the input.
        A[1][0] = 1000.0

        quantized, stats, outliers = op(A, threshold=threshold)

        assert quantized.shape == (10, 20)
        assert quantized.dtype == torch.int8
        assert quantized.device == A.device

        assert stats.shape == (10,)
        assert stats.dtype == torch.float32
        assert stats.device == A.device

        if threshold > 0.0:
            # With a positive threshold a 1-D tensor of column indices is expected.
            assert outliers is not None
            assert outliers.dim() == 1
            assert outliers.shape[0] <= A.shape[1]
            assert outliers.device == A.device
        else:
            assert outliers is None

        torch.library.opcheck(op, (A,))
        torch.library.opcheck(op, (A, threshold))

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    def test_int8_mm_dequant(self, device):
        """``int8_mm_dequant`` maps a (256, 256) int32 input to a float16 tensor of the same shape."""
        op = torch.ops.bitsandbytes.int8_mm_dequant
        A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
        row_stats = torch.randn(256, dtype=torch.float32, device=device)
        col_stats = torch.randn(256, dtype=torch.float32, device=device)

        result = op(A, row_stats, col_stats)

        assert result.shape == A.shape
        assert result.dtype == torch.float16
        assert result.device == A.device

        torch.library.opcheck(op, (A, row_stats, col_stats))

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("has_bias", TRUE_FALSE)
    def test_int8_scaled_mm(self, device, dtype, has_bias):
        """``int8_scaled_mm`` returns a (10, 30) tensor in the requested dtype, with or without a bias."""
        op = torch.ops.bitsandbytes.int8_scaled_mm
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
        row_stats = torch.randn(10, dtype=torch.float32, device=device)
        col_stats = torch.randn(30, dtype=torch.float32, device=device)

        bias = None
        if has_bias:
            bias = torch.randn(30, dtype=dtype, device=device)

        result = op(A, B, row_stats, col_stats, bias=bias, dtype=dtype)

        assert result.shape == (10, 30)
        assert result.dtype == dtype
        assert result.device == A.device

        torch.library.opcheck(op, (A, B, row_stats, col_stats, bias, dtype))
class TestInt8BlockwiseQuantOps:
    """Shape/dtype/device checks plus ``torch.library.opcheck`` for the blockwise 8-bit ops."""

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_quantize_blockwise(self, device, dtype, blocksize):
        """``quantize_blockwise`` returns uint8 data shaped like the input and float32 absmax values."""
        if device == "cpu" and dtype != torch.float32:
            pytest.skip("CPU implementation is only available for float32")

        op = torch.ops.bitsandbytes.quantize_blockwise
        code = bitsandbytes.functional.create_dynamic_map().to(device)
        A = torch.randn(1024, 1024, dtype=dtype, device=device)

        quantized, absmax = op(A, code, blocksize)

        assert quantized.shape == A.shape
        assert quantized.dtype == torch.uint8
        assert quantized.device == A.device

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        torch.library.opcheck(op, (A, code, blocksize))

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_dequantize_blockwise(self, device, dtype, blocksize):
        """``dequantize_blockwise.default`` returns a tensor shaped like the input in the requested dtype."""
        if device == "cpu" and dtype != torch.float32:
            pytest.skip("CPU implementation is only available for float32")

        op = torch.ops.bitsandbytes.dequantize_blockwise.default
        A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

        # One absmax entry per block: ceil(numel / blocksize).
        num_elements = A.numel()
        num_blocks = -(-num_elements // blocksize)
        absmax = torch.randn((num_blocks,), device=device, dtype=torch.float32)

        result = op(A, absmax, code, blocksize, dtype)

        assert result.shape == A.shape
        assert result.dtype == dtype
        assert result.device == A.device

        torch.library.opcheck(op, (A, absmax, code, blocksize, dtype))
class Test4bitBlockwiseQuantOps:
    """Checks for the 4-bit blockwise custom ops (``quantize_4bit``, ``dequantize_4bit``, ``gemv_4bit``).

    Each test verifies basic output properties (device, dtype, shape where
    known) and then runs ``torch.library.opcheck`` on the same arguments.
    """

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        """``quantize_4bit`` returns packed data in ``storage_dtype`` and float32 absmax on the input device."""
        if device == "cpu" and quant_type != "nf4":
            pytest.skip("CPU implementation is only available for nf4")

        A = torch.randn(1024, 1024, dtype=dtype, device=device)

        out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)

        assert out.device == A.device
        assert out.dtype == storage_dtype

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        """``dequantize_4bit.default`` restores the requested shape and dtype from random packed data."""
        if device == "cpu":
            pytest.skip("CPU implementation is not available")

        shape = (128, 128)

        n = prod(shape)
        # Ceil division: one absmax entry per block of ``blocksize`` elements.
        blocks = -(n // -blocksize)
        # Two 4-bit values per byte, ``itemsize`` bytes per storage element.
        quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)

        # Random bytes reinterpreted as ``storage_dtype``; only the op's
        # output metadata is checked, not the decoded values.
        A = (
            torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device)
            .view(storage_dtype)
            .reshape(quantized_shape)
            .contiguous()
        )

        absmax = torch.randn((blocks,), dtype=torch.float32, device=device)

        out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)

        assert out.device == A.device
        # The op is asked for ``dtype`` explicitly, so the result must carry it.
        assert out.dtype == dtype
        assert out.shape == shape

        torch.library.opcheck(
            torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
        )

    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        """``gemv_4bit.default`` on a (1, 1, in_features) input yields a finite (1, 1, out_features) result."""
        if device == "cpu":
            pytest.skip("CPU implementation is not available")

        out_features = 1024
        in_features = 256

        A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
        B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
        # Quantize real weights so the packed data and absmax are consistent.
        B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
        code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)

        out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)

        assert out.device == A.device
        assert out.dtype == dtype
        assert out.shape == (1, 1, out_features)
        # ``isreal`` is True for every element of a real-dtype tensor (NaN and
        # Inf included), so it cannot fail here; ``isfinite`` actually rejects
        # NaN/Inf produced by a broken kernel.
        assert out.isfinite().all()

        torch.library.opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment