Commit bfa0e332 authored by Titus von Koeller's avatar Titus von Koeller
Browse files

ran black and isort for coherent code formatting

parent 597a8521
import pytest
from itertools import product
import pytest
import torch
import bitsandbytes as bnb
from itertools import product
import bitsandbytes as bnb
# Parametrization grid for test_matmul.
n = 1  # number of random sizes drawn per dimension
# NOTE(review): `k` is not read in the visible part of this section; presumably a
# per-case repetition count -- confirm against the test body.
k = 25

# The four draws consume torch's global RNG in exactly this order.
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# Each axis has a value list and a parallel list of short labels used in test ids.
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]

# `values` feeds the test; `str_values` is the same product over the label lists,
# so both enumerate the cases in the same order.
values = [*product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)]
str_values = [
    *product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
]
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *labels
    )
    for labels in str_values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
......@@ -32,9 +43,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
......@@ -52,9 +65,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
......@@ -78,16 +91,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
A = torch.randn(
size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
)
B = torch.randn(
size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
)
target = torch.randn(
size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
......@@ -95,7 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.01
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad):
......@@ -120,16 +139,20 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
A = torch.randn(
size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
......@@ -141,9 +164,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001
assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
......@@ -167,51 +190,96 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
assert (idx == 0).sum().item() < n * 0.02
# Parametrization grid for test_matmullt.
n = 1  # number of random sizes drawn per dimension
k = 3  # inner repetitions: the test body loops `for i in range(k)`

# The four draws consume torch's global RNG in exactly this order.
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# dim1 = (17,)
# dim2 = (7,)
# dim3 = (37,)
# dim4 = (23,)

# Each axis has a value list; axes with non-printable values also get a label list.
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]

# `values` feeds the test; `str_values` is the same product over the label lists,
# so both enumerate the cases in the same order.
values = [
    *product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
    )
]
str_values = [
    *product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
    )
]
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
        *labels
    )
    for labels in str_values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
values,
ids=names,
)
def test_matmullt(
dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
A = torch.randn(
size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
)
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
B = torch.randn(
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
target = torch.randn(
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
)
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
......@@ -219,8 +287,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
state.threshold = decomp
state.has_fp16_weights = has_fp16_weights
if not has_fp16_weights:
if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
if not transpose[0] and not transpose[1]:
B2 = B2.t().contiguous()
(
state.CB,
CBt,
state.SCB,
SCBt,
coo_tensorB,
) = bnb.functional.double_quant(B2)
B2 = state.CB
if not transpose[0] and transpose[1]:
......@@ -231,12 +306,12 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
out_bnb = funcs[1](A, B2.t(), state=state)
n = out_bnb.numel()
err = torch.abs(out_bnb-out_torch).mean().item()
#print(f'abs error {err:.4f}')
err = torch.abs(out_bnb - out_torch).mean().item()
# print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175
assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001
assert (idx == 0).sum().item() < n * 0.001
if has_fp16_weights:
if any(req_grad):
......@@ -263,8 +338,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
import pytest
import os
from typing import List, NamedTuple
import pytest
from typing import List
from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
get_cuda_runtime_lib_path, tokenize_paths)
from bitsandbytes.cuda_setup import (
CUDA_RUNTIME_LIB,
get_cuda_runtime_lib_path,
evaluate_cuda_setup,
tokenize_paths,
)
class InputAndExpectedOutput(NamedTuple):
    """One test case: an input string paired with the output expected for it."""

    # Input string given to the function under test.
    input: str
    # Value the function under test is expected to produce for `input`.
    output: str
# Colon-separated path strings that each contain exactly one entry ending in
# CUDA_RUNTIME_LIB, in different positions and with stray/empty separators.
# Every case expects the same result: the entry that names the runtime library.
# Entries are real InputAndExpectedOutput instances so the annotation is true;
# they still unpack as plain 2-tuples for pytest parametrization.
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    InputAndExpectedOutput(input=ld_path, output=f"dir/with/{CUDA_RUNTIME_LIB}")
    for ld_path in (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
    )
]
@pytest.mark.parametrize(
"test_input, expected",
HAPPY_PATH__LD_LIB_TEST_PATHS
)
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
for path in tokenize_paths(request.param):
test_dir.mkdir()
if CUDA_RUNTIME_LIB in path:
(test_input / CUDA_RUNTIME_LIB).touch()
@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
tmp_path, test_input: str, expected: str
tmp_path, test_input: str, expected: str
):
for path in tokenize_paths(test_input):
assert False == tmp_path / test_input
test_dir.mkdir()
(test_input / CUDA_RUNTIME_LIB).touch()
path.mkdir()
(path / CUDA_RUNTIME_LIB).touch()
assert get_cuda_runtime_lib_path(test_input) == expected
......@@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
(test_input / CUDA_RUNTIME_LIB).touch()
with pytest.raises(FileNotFoundError) as err_info:
get_cuda_runtime_lib_path(test_input)
assert all(
match in err_info
for match in {"duplicate", CUDA_RUNTIME_LIB}
)
assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB})
def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    """A search path containing a missing directory must produce a warning on stderr.

    Builds a two-entry, colon-separated path string (one existing directory, one
    that is never created), runs the lookup, and checks that the captured stderr
    mentions both "WARNING" and "non-existent".
    """
    existent_dir = tmp_path / "a/b"
    # "a" does not exist under the fresh tmp_path, so the intermediate
    # directory has to be created as well; a bare mkdir() would raise.
    existent_dir.mkdir(parents=True)
    non_existent_dir = tmp_path / "c/d"  # non-existent dir
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    get_cuda_runtime_lib_path(test_input)
    std_err = capsys.readouterr().err

    assert all(match in std_err for match in {"WARNING", "non-existent"})
def test_full_system():
    """Check that the selected binary name matches the CUDA version on LD_LIBRARY_PATH.

    The CUDA version is read from the first LD_LIBRARY_PATH entry that contains
    a ``cuda-<major>.<minor>`` component. The test is skipped when no such
    entry exists, instead of failing on a missing variable or passing
    vacuously against an empty version string.
    """
    import re

    # this only tests the cuda version and not compute capability
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")

    version_digits = None
    for p in ld_path.split(":"):
        found = re.findall(r"cuda-(\d+)\.(\d+)", p)
        if found:
            # Use the last occurrence within the entry, mirroring the previous
            # rfind("cuda-") behaviour; "11.3" becomes "113".
            major, minor = found[-1]
            version_digits = major + minor
            break

    if version_digits is None:
        pytest.skip("no 'cuda-<major>.<minor>' directory found in LD_LIBRARY_PATH")

    binary_name = evaluate_cuda_setup()
    binary_name = binary_name.replace("libbitsandbytes_cuda", "")
    assert binary_name.startswith(version_digits)
import pytest
import math
import random
import time
import torch
import bitsandbytes as bnb
import einops
from itertools import product
import einops
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
torch.set_printoptions(
precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20
def assert_all_approx_close(a, b, rtol, atol, count):
    """Assert that at most ``count`` elements of ``a`` and ``b`` are not close.

    Closeness is decided element-wise by ``torch.isclose`` with the given
    ``rtol`` and ``atol``. When more than ``count`` elements differ, a short
    message is printed and ``torch.testing.assert_allclose`` is invoked on the
    full tensors so the failure is reported by torch.
    """
    num_mismatched = torch.isclose(a, b, rtol, atol).logical_not().sum().item()
    if num_mismatched <= count:
        return
    print(f"Too many values not close: assert {num_mismatched} < {count}")
    torch.testing.assert_allclose(a, b, rtol, atol)
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
super(FFN, self).__init__()
......@@ -35,13 +39,14 @@ class FFN(torch.nn.Module):
x = self.fc2(x)
return x
class Timer(object):
def __init__(self):
self.starts = {}
self.ends = {}
self.agg = {}
def tick(self, name='default'):
def tick(self, name="default"):
if name not in self.starts:
self.starts[name] = torch.cuda.Event(enable_timing=True)
self.ends[name] = torch.cuda.Event(enable_timing=True)
......@@ -49,66 +54,70 @@ class Timer(object):
else:
ms = self.tock(name, evict=True, print_ms=False)
def tock(self, name='default', evict=True, print_ms=True):
def tock(self, name="default", evict=True, print_ms=True):
if name in self.ends:
self.ends[name].record()
torch.cuda.synchronize()
ms = self.starts[name].elapsed_time(self.ends[name])
if name not in self.agg: self.agg[name] = 0.0
if name not in self.agg:
self.agg[name] = 0.0
self.agg[name] += ms
if evict:
self.starts.pop(name)
self.ends.pop(name)
if print_ms and name in self.agg:
print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0))
print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
return self.agg[name]
def reset(self):
self.starts = {}
self.starts = {}
self.ends = {}
self.agg = {}
print('Resetting benchmark data')
print("Resetting benchmark data")
def setup():
    """No-op module-level setup hook; takes no arguments and does nothing."""
    pass
def teardown():
    """No-op module-level teardown hook; takes no arguments and does nothing."""
    pass
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    """Check F.estimate_quantiles on uniform and normal samples for ``dtype``.

    Uniform case: the returned code is compared against 256 evenly spaced
    values from 1/512 to 511/512. Normal case: the code is compared against
    ``torch.quantile`` at those same points, and no entry may be off by more
    than 0.05.
    """
    # Uniform samples.
    uniform = torch.rand(1024, 1024, device="cuda").to(dtype)
    code = F.estimate_quantiles(uniform)
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=uniform.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    # Normal samples; the reference quantiles are computed in float32.
    normal = torch.randn(1024, 1024, device="cuda").to(dtype)
    code = F.estimate_quantiles(normal)
    reference = torch.quantile(normal.float(), percs)
    abs_err = torch.abs(code - reference)
    num_too_far = (abs_err > 5e-02).sum().item()
    assert num_too_far == 0
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1-A2).mean().item()
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device='cuda')
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1-A2).mean().item()
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
......@@ -117,22 +126,22 @@ def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1-A2)
reldiff = diff/torch.abs(A1+1e-8)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
#print(sum(diffs)/len(diffs))
#print(sum(reldiffs)/len(reldiffs))
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device='cuda')
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1-A2).mean().item()
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
......@@ -141,56 +150,60 @@ def test_dynamic_blockwise_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1-A2)
reldiff = diff/torch.abs(A1+1e-8)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
#print(sum(diffs)/len(diffs))
#print(sum(reldiffs)/len(reldiffs))
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
A1 = torch.rand(1024, 1024, device='cuda')
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1-A2).mean().item()
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0033
diffs.append(diff)
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
#print(sum(diffs)/len(diffs))
# print(sum(diffs)/len(diffs))
def test_dynamic_blockwise_stochastic_quantization():
    """Compare blockwise quantization with and without the ``rand`` argument.

    For 100 random normal matrices, the codes produced with ``rand`` may differ
    from the codes produced without it by at most 1, and the fraction of
    entries that came out larger must match the fraction that came out smaller
    to within 0.01.
    """
    rand = torch.rand(1024).cuda()
    for _ in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        with_rand, _state_rand = F.quantize_blockwise(A1, rand=rand)
        without_rand, _state_plain = F.quantize_blockwise(A1)

        # a maximum distance of quantized values of 1
        torch.testing.assert_allclose(with_rand, without_rand, atol=1, rtol=0)

        total = with_rand.numel()
        fraction_smaller = (with_rand < without_rand).float().sum() / total
        fraction_larger = (with_rand > without_rand).float().sum() / total
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device='cuda')
gnorm_vec2 = torch.zeros(100, device='cuda')
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
percentile=5
percentile = 5
for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device='cuda')
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1
g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping(
g, gnorm_vec2, step, percentile=percentile
)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
......@@ -208,74 +221,89 @@ def test_percentile_clipping(gtype):
def quant(x):
    """Quantize ``x`` to int8 with one scale taken from its largest magnitude.

    Args:
        x: tensor to quantize.

    Returns:
        ``(max1, q)`` where ``max1`` is the scale used (the maximum absolute
        value of ``x``) and ``q`` is ``round(x / max1 * 127)`` cast to int8.
        For an all-zero ``x`` the scale is reported as 1.0 and ``q`` is all
        zeros, matching how ``quant_multi`` treats zero maxima.
    """
    max1 = torch.abs(x).max()
    # Avoid 0/0 -> NaN for all-zero input; NaN has no defined int8 value.
    max1 = torch.where(max1 == 0, torch.ones_like(max1), max1)
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)
def dequant(c, maxC):
    """Return ``c`` converted to float and multiplied by ``maxC / 127``."""
    step = maxC / 127
    return c.float() * step
def mm_dequant(maxA, maxB, C):
    """Return ``C`` as float, scaled by ``maxA / 127`` and then by ``maxB / 127``."""
    scale_a = maxA / 127
    scale_b = maxB / 127
    return C.float() * scale_a * scale_b
def quant_multi(x, dim):
    """Quantize ``x`` to int8 with one scale per slice along ``dim``.

    Returns ``(max1, q)``: ``max1`` holds the maximum absolute value along
    ``dim`` (kept as a size-1 dimension, with zeros replaced by 1.0) and ``q``
    is ``round(x / max1 * 127)`` cast to int8.
    """
    max1 = torch.abs(x).amax(dim=dim, keepdim=True)
    # A zero maximum would divide by zero below; use a scale of 1 instead.
    max1.masked_fill_(max1 == 0, 1.0)
    quantized = torch.round(x / max1 * 127).to(torch.int8)
    return max1, quantized
def quant_multi_chunk(x, dim, chunk_size=32):
    """Quantize a 2-D tensor to int8 using maxima taken over chunked views of ``x``.

    ``x`` is rearranged into a 3-D view with ``einops.rearrange``, absolute
    maxima are reduced over one axis of that view, tiled back and viewed to
    the shape of ``x``, and used as element-wise scales.

    Args:
        x: 2-D tensor; the axis that is split must be divisible by ``chunk_size``.
        dim: 0 or 1, selecting which of the two chunking layouts is used.
        chunk_size: value bound to ``c`` in the rearrange pattern.

    Returns:
        ``(max1, q)`` where ``max1`` has the shape of ``x`` (zeros replaced by
        1.0) and ``q`` is ``round(x / max1 * 127)`` cast to int8.

    Raises:
        ValueError: if ``dim`` is neither 0 nor 1. Previously this fell through
            both branches and failed later on an unbound local.
    """
    if dim not in (0, 1):
        raise ValueError(f"quant_multi_chunk only supports dim 0 or 1, got {dim!r}")

    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
    else:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
    max1 = max1.view(x.shape)

    # A zero maximum would divide by zero below; use a scale of 1 instead.
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)
def quant_minmax(A):
    """Compute the overall minimum and maximum of ``A``.

    NOTE(review): both values are bound to locals and never used or returned,
    so the function returns ``None``. This looks unfinished -- confirm the
    intended behaviour before relying on it.
    """
    minA = A.min()
    maxA = A.max()
def mean(xx):
    """Return the arithmetic mean of the sized collection ``xx`` as a true division."""
    total = sum(xx)
    return total / float(len(xx))
# Parametrization grid for test_approx_igemm.
# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]

# Each entry is a 5-tuple of callables: two quantizers called as f(x, dim),
# followed by dequant, dequant and mm_dequant. The first entry ignores `dim`.
methods = [
    (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant),
    (quant_multi, quant_multi, dequant, dequant, mm_dequant),
]
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]  # labels parallel to `methods`
batched = [False, True]

# `values` feeds the test; `values_names` is the same product with labels in
# place of the callables and is only used to build readable ids.
values = [*product(dim1, dim2, methods, batched)]
values_names = [*product(dim1, dim2, method_names, batched)]
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(d1, d2, method, is_batched)
    for (d1, d2, method, is_batched) in values_names
]
@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
print('')
print("")
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda')
B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda')
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1)
else:
A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda')
B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda')
A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
torch.testing.assert_allclose(
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
)
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
......@@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C)
std = out2.std()
out/= std
out2/= std
err = torch.abs(out-out2)
relerr = err/torch.abs(out2)
out /= std
out2 /= std
err = torch.abs(out - out2)
relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
print(mean(errors))
print(mean(relerrors))
def test_stable_embedding():
    """Smoke test: build a 1024x1024 StableEmbedding and call reset_parameters on it."""
    embedding = bnb.nn.StableEmbedding(1024, 1024)
    embedding.reset_parameters()
# Parametrization grid for test_igemm.
n = 2  # number of random sizes drawn per dimension

# The three draws consume torch's global RNG in exactly this order.
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]

values = [*product(hidden_dim, batch_dim, transpose, seq_dim)]
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(
        hidden, batch, flags, seq
    )
    for (hidden, batch, flags, seq) in values
]
@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
shapeA = (
(batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
......@@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
......@@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
# Parametrization grid for test_dim3_igemm.
n = 3  # number of random sizes drawn per dimension

# The three draws consume torch's global RNG in exactly this order.
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()

values = [*product(seq_dim, hidden_dim, batch_dim)]
names = [
    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(seq, hidden, batch)
    for (seq, hidden, batch) in values
]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    """Check F.igemm on two 3-D int8 tensors against a float einsum reference.

    The sizes are first rounded down to multiples of 32 (seq, hidden) and 2
    (batch). Over 25 random draws, ``F.igemm(A, B, out=iout)`` must match
    ``einsum("bsi, bso->io")`` computed on the float copies of A and B.
    """
    seq_dim -= seq_dim % 32
    hidden_dim -= hidden_dim % 32
    batch_dim -= batch_dim % 2

    shape_a = (batch_dim, seq_dim, hidden_dim)
    shape_b = (batch_dim, seq_dim, 1024)
    for _ in range(25):
        A = torch.randint(-128, 127, size=shape_a, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shape_b, device="cuda").to(torch.int8)
        reference = torch.einsum("bsi, bso->io", A.float(), B.float())
        # int32 output buffer of shape (hidden_dim, 1024) passed through `out=`.
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        result = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(result.float(), reference)
# Parametrization grid for test_minmax_igemm.
n = 2  # number of random sizes drawn per dimension

# The three draws consume torch's global RNG in exactly this order.
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]

values = [*product(seq_dim, hidden_dim, batch_dim, transpose)]
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(
        seq, hidden, batch, flag
    )
    for (seq, hidden, batch, flag) in values
]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
scale = (maxA-minA)/2.0
return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale
scale = (maxA - minA) / 2.0
return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
seq_dim = seq_dim - (seq_dim % 16)
hidden_dim = hidden_dim - (hidden_dim % 16)
......@@ -395,30 +444,30 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = []
relerrs2 = []
for i in range(k):
A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda')
A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda')
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda')
B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
Ac, minA, scale = min_max(A)
if transpose:
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
out = F.igemm(Ac, Bc.t())
out2 = torch.matmul(A,B.t())
offset = B.t().sum(0)*(minA+scale)
out2 = torch.matmul(A, B.t())
offset = B.t().sum(0) * (minA + scale)
out = out.float()
out = (out*maxB.t()*scale/(127*127))+offset
out = (out * maxB.t() * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
else:
maxB, Bc = quant_multi(B, dim=0)
offset = B.sum(0)*(minA+scale)
offset = B.sum(0) * (minA + scale)
out = F.igemm(Ac, Bc)
out2 = torch.matmul(A,B)
out2 = torch.matmul(A, B)
out = out.float()
out = (out*maxB*scale/(127*127))+offset
out = (out * maxB * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc)
......@@ -429,31 +478,36 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
out /= std
out3 /= std
err = torch.abs(out-out2)
relerr = err/(torch.abs(out2)+1e-7)
err = torch.abs(out - out2)
relerr = err / (torch.abs(out2) + 1e-7)
err2 = torch.abs(out3-out2)
relerr2 = err2/(torch.abs(out2)+1e-7)
err2 = torch.abs(out3 - out2)
relerr2 = err2 / (torch.abs(out2) + 1e-7)
errs.append(err.mean().item())
relerrs.append(relerr.mean().item())
errs2.append(err2.mean().item())
relerrs2.append(relerr2.mean().item())
#print(mean(errs))
#print(mean(relerrs))
#print(mean(errs2))
#print(mean(relerrs2))
# print(mean(errs))
# print(mean(relerrs))
# print(mean(errs2))
# print(mean(relerrs2))
assert mean(errs) < 0.015
assert mean(relerrs) < 0.3
# Parameter grid for test_ibmm: two random draws per dimension, crossed with
# every (transpose A, transpose B) combination.
n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
# One readable pytest id per parameter combination.
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16)
......@@ -462,8 +516,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
for i in range(k):
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.bmm(A.float(), B.float())
......@@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float())
# Parameter grid for test_vector_quant: a single random draw per dimension.
n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
# One readable pytest id per parameter combination.
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    """Round-trip a random matrix through vector-wise quantisation.

    ``F.vectorwise_quant`` followed by ``F.vectorwise_dequant`` must
    reproduce the input within ``atol=0.01`` / ``rtol=0.1``. ``dim1`` is
    accepted for the parametrisation but not used by the body.
    """
    # Round both matrix dimensions down to multiples of 16.
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for _ in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)
# Parameter grid for test_nvidia_transform: two random draws per dimension,
# crossed with tensor rank, dtype and source/target layouts.
n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
# One readable pytest id per parameter combination.
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
        *vals
    )
    for vals in values
]
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """Check ``F.nvidia_transform`` from row-major into ``orderOut``.

    For "row"/"col" the flattened output is compared with the (transposed)
    input; for "col32" only the padded element count is checked, followed
    by a round trip back to "row". 3-D and int32 inputs are exercised for
    "col32" only; every other combination returns immediately.
    """
    # The guards must test the ``orderOut`` parameter. They previously
    # compared the module-level ``out_order`` *list* with a string, which is
    # always unequal, so every 3-D and every int32 case returned untested.
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    # Result is unused; the lookup is kept so it still runs for each combo.
    F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        # NOTE(review): this adds a full extra 32 when the last dimension is
        # already a multiple of 32 -- confirm it matches the library padding.
        if dims == 2:
            expected_numel = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            expected_numel = (
                A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == expected_numel
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        expected_numel = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == expected_numel
        # Only the row-major indexing of ``A`` itself is verified here; the
        # tile-offset comparison against ``out`` was never enabled.
        flat_A = A.flatten()
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                assert flat_A[row * A.shape[1] + col] == A[row, col]

    if orderOut == "col32":
        # Round trip back to row-major must reproduce the input exactly.
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_allclose(A, out2)
# Parameter grid for test_igemmlt_int: one random draw per dimension, crossed
# with 2-D and 3-D left operands. ``ldb`` is fixed at 0.
n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
dims = (2, 3)
ldb = [0]
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
# One readable pytest id per parameter combination.
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    """Compare ``F.igemmlt`` on int8 operands with a float ``matmul`` reference.

    Each iteration checks ``A @ B.t()`` with ``B`` stored as
    ``(dim4, dim3)``, then ``A @ B`` with ``B`` stored as ``(dim3, dim4)``
    and transformed with ``transpose=True``. ``dim2`` is only used for the
    3-D case; ``ldb`` is accepted but not used by the body.
    """
    for _ in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
                torch.int8
            )
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose: reuse the transformed A against a fresh (dim3, dim4) B.
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())
# Parameter grid for test_igemmlt_half: a single fixed 32x32 configuration
# with a 2-D left operand.
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
values = list(product(dim1, dim2, dim3, dim4, dims))
# One readable pytest id per parameter combination.
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
formatB = F.get_special_format_str()
for i in range(k):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half()
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half()
B = torch.randn((dim4, dim3), device='cuda').half()
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
C2 = bnb.matmul(A, B.t())
......@@ -627,50 +709,56 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
CxB, SB = F.transform(CB, to_order=formatB)
out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
#print('')
#print(output.flatten()[:10])
#print(C1.flatten()[:10])
#print(C2.flatten()[:10])
# print('')
# print(output.flatten()[:10])
# print(C1.flatten()[:10])
# print(C2.flatten()[:10])
#torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose
#B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
#C1 = torch.matmul(A.float(), B.float())
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float())
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
# C2, SC = F.igemmlt(A2, B2t, SA, SBt)
# C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_allclose(C1, C3.float())
#B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
#C2, SC = F.igemmlt(A2, B2t, SA, SBt)
#C3, S = F.transform(C2, 'row', state=SC)
#torch.testing.assert_allclose(C1, C3.float())
# Benchmark configurations for test_bench_8bit_training, each given as
# (batch, seq, model, hidden).
batch_size = 2
seqdim = 512
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]
# One readable pytest id per configuration.
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device='cuda').half()
grad = torch.randn(batch, seq, model, device='cuda').half()
w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half()
w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half()
print('')
A = torch.randn(batch, seq, model, device="cuda").half()
grad = torch.randn(batch, seq, model, device="cuda").half()
w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
print("")
#torch.cuda.synchronize()
# torch.cuda.synchronize()
## warmup
#for i in range(100):
# for i in range(100):
# torch.matmul(A, w1.t())
#torch.cuda.synchronize()
# torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
......@@ -679,77 +767,77 @@ def test_bench_8bit_training(batch, seq, model, hidden):
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
#out2 = torch.matmul(out1, w2.t())# fc2
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
#d1 = torch.matmul(grad, w2) # delta1
#d2 = torch.matmul(d1, w1) # delta2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
#grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
#grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
#torch.cuda.empty_cache()
# torch.cuda.empty_cache()
#Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#CTw1, Sw1 = F.transform2(Cw1, formatB)
#CTw2, Sw2 = F.transform2(Cw2, formatB)
#CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
#CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
#CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
#C32A, SA = F.transform2(CA, 'col32')
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
#out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
#Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
#C32out1, Sout1 = F.transform2(Cout1, 'col32')
#out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
#Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
#C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
#Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
#C32d1, Sd1 = F.transform2(Cd1, 'col32')
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
#C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
#CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
#C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
#CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#CTw1, Sw1 = F.transform2(Cw1, formatB)
#CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
#CTw2, Sw2 = F.transform2(Cw2, formatB)
#CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(k):
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
......@@ -802,74 +890,76 @@ def test_bench_8bit_training(batch, seq, model, hidden):
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
#torch.cuda.synchronize()
#t8 = time.time() - t0
#print(t8)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
# Parameter grid for test_dequant_mm: two random draws per outer dimension,
# crossed with both special B-matrix layouts.
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()
dims = (2,)
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
# One readable pytest id per parameter combination.
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
    """Check the int8 matmul + dequantisation paths against a half matmul.

    ``F.vectorwise_mm_dequant`` may disagree with the half-precision
    reference in fewer than 6% of elements (``atol=0.01``, ``rtol=0.1``);
    ``F.mm_dequant`` must then agree with ``F.vectorwise_mm_dequant``.
    ``dims`` is accepted but not used by the body.
    """
    inner = torch.randint(1, 128, size=(1,)).item()
    # The parametrised ``formatB`` is replaced here, so it only affects the
    # test id; the layout actually used comes from the library.
    formatB = F.get_special_format_str()
    for _ in range(k):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)

        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        # Count elements outside tolerance and bound their fraction.
        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        assert (
            count / n < p
        ), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
        torch.testing.assert_allclose(C5, C4)
# Parameter grid for test_colrow_absmax: a single fixed 1024x1024 matrix.
# ``n`` is kept although the dimensions are currently not random.
n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
dims = (2,)
values = list(product(dim1, dim2, dims))
# One readable pytest id per parameter combination.
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
for i in range(k):
threshold = 3.0
A = torch.randn(dim1, dim2, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
if dims == 2:
......@@ -880,11 +970,22 @@ def test_colrow_absmax(dim1, dim2, dims):
else:
assert False
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4)
nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten()
nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=threshold
)
A_blocked = einops.rearrange(
torch.abs(A),
"(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
row_tiles=16,
block_size=64 * 4,
)
nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
nnz_block_ptr1 = torch.zeros(
nnz_rows1_counts.shape[0] + 1,
dtype=nnz_rows1_counts.dtype,
device=nnz_rows1_counts.device,
)
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
......@@ -898,19 +999,20 @@ def test_colrow_absmax(dim1, dim2, dims):
assert nnz_block_ptr2 is None
# Parameter grid for test_double_quant: two random draws per dimension,
# fully crossed.
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
# One readable pytest id per parameter combination.
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
......@@ -920,18 +1022,21 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
# allow for 1:500 error due to rounding differences
min_error = 1/500
if num_not_close_cols > (min_error*n):
print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}')
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
print(
f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
)
assert False
if num_not_close_rows > (min_error*n):
print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}')
if num_not_close_rows > (min_error * n):
print(
f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
)
assert False
torch.testing.assert_allclose(Srow.flatten(), statsA)
......@@ -939,21 +1044,23 @@ def test_double_quant(dim1, dim2):
# Parameters for test_integrated_igemmlt. The dimensions are pinned to one
# small case; the random draws that used to precede these assignments were
# overwritten before use and have been dropped.
n = 4
dim1 = [6]
dim4 = [4]
inner = [8]

# Dimensions are paired element-wise, not crossed.
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device='cuda').half()
B = torch.randn(dim4, inner, device='cuda').half()
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
......@@ -967,30 +1074,32 @@ def test_integrated_igemmlt(dim1, dim4, inner):
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
A2, SA = F.nvidia_transform(C1a, 'col32')
B2, SB = F.nvidia_transform(C2a, 'col_turing')
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(C2a, "col_turing")
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
A2, SA = F.nvidia_transform(A1, 'col32')
B2, SB = F.nvidia_transform(B1, 'col_turing')
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, 'row', state=SC)
C3, S = F.nvidia_transform(C2, "row", state=SC)
out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
err1 = torch.abs(out1-out2).mean().item()
err2 = torch.abs(out1-out3).mean().item()
assert err2 <= err1*1.01
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.01
# Parameters for test_igemmlt_row_scale: six random triples, with the
# dimensions paired element-wise rather than crossed.
n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
......@@ -999,79 +1108,79 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
A = torch.randn(dim1, inner, device='cuda').half()
B = torch.randn(dim4, inner, device='cuda').half()
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
A2, SA = F.nvidia_transform(C1a, 'col32')
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0*inner*scale
row_scale = torch.ones_like(maxA)/c
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
C3, S = F.nvidia_transform(outC32, 'row', state=SC)
C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
scale = maxval/120
out3 = C3*maxA*absmaxB*c/(127*127)
scale = maxval / 120
out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector')
CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear')
CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
out4 = C*SA*SB/(127*127)
#out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
out4 = C * SA * SB / (127 * 127)
# out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
#print('='*80)
#print(out1)
#print(out2)
#print(out3)
# print('='*80)
# print(out1)
# print(out2)
# print(out3)
#print(out1)
#print(out2)
#print(out3)
err1.append(torch.abs(out1-out2).mean().item())
err2.append(torch.abs(out1-out3).mean().item())
err3.append(torch.abs(out1-out4).mean().item())
# print(out1)
# print(out2)
# print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
#assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print('')
print(sum(err1)/len(err1))
print(sum(err2)/len(err2))
print(sum(err3)/len(err3))
# assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print("")
print(sum(err1) / len(err1))
print(sum(err2) / len(err2))
print(sum(err3) / len(err3))
# Shapes for the row-scale benchmark. The three lists are zipped, not crossed:
# entry i of each list forms one (dim1, dim4, inner) test case.
dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device='cuda').half()
B = torch.randn(dim4, inner, device='cuda').half()
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmpup
for i in range(k):
......@@ -1082,23 +1191,22 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
print('16', time.time()-t0)
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
A2, SA = F.nvidia_transform(C1a, 'col32')
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0*inner*scale
row_scale = maxA/c
c = 10.0 * inner * scale
row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
torch.cuda.synchronize()
print('row-wise', time.time()-t0)
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
......@@ -1107,32 +1215,39 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB)
torch.cuda.synchronize()
print('vector-wise', time.time()-t0)
print("vector-wise", time.time() - t0)
# Parameter grid for the layout-transform test: n random sizes per dimension,
# crossed with every output layout and both transpose settings.
n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]
dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
    )
    for vals in values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype)
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype)
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
A.view(-1)[-1] = -1
if transpose:
......@@ -1144,53 +1259,55 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
#print(out1)
#print(out2)
# print(out1)
# print(out2)
torch.testing.assert_allclose(out1, out2)
# Single fixed case for the transform-to-row round trip: a 1 x 33 int8 matrix
# converted from "col_turing" back to "row".
n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
    for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    """Round-trip a random matrix through ``orderA`` and back to row order.

    Only the shape of the round-tripped tensor is asserted; the element-wise
    comparison is currently disabled (see the commented-out check at the end).
    ``orderOut`` is accepted for parametrization but is not used: the target
    layout is hard-coded to ``"row"``.
    """
    A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)

    out2, S2 = F.transform(A, to_order=orderA)
    A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
    assert A2.shape[0] == A.shape[0]
    assert A2.shape[1] == A.shape[1]

    # Diagnostic output: original, transformed and round-tripped tensors.
    print("")
    print(A)
    print(out2)
    print(A2)

    # torch.testing.assert_allclose(A, A2)
def test_overflow():
formatB = F.get_special_format_str()
print(formatB)
for i in range(2):
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
Ca, Sa = F.nvidia_transform(a, 'col32')
Ca, Sa = F.nvidia_transform(a, "col32")
Cb, Sb = F.nvidia_transform(b, formatB)
c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
......@@ -1198,46 +1315,51 @@ def test_overflow():
# Parameter grid for the COO double-quant test: n random sizes per dimension.
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    """Check the outlier and int8 outputs of ``F.double_quant`` with a threshold.

    For ``k`` random fp16 matrices:

    * the entries with ``|A| >= threshold`` must equal the values scattered
      from the returned COO tensor (``rowidx`` / ``colidx`` / ``values``);
    * the remaining entries must be recovered from the int8 output scaled by
      its per-row statistics (``CA * statsA / 127``) within
      ``rtol=0.05, atol=1.5e-2``.
    """
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        # Call without a threshold as well; its outputs are not checked here.
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

        # NOTE(review): indentation was lost in this copy of the file; both
        # checks are assumed to sit under this guard -- confirm against the
        # repository version.
        if coo_tensor is not None:
            # Outliers: dense reconstruction from the COO triplets.
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            # Non-outliers: dequantize the int8 matrix with the row statistics.
            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(A1, A2, rtol=0.05, atol=1.5e-2)
# Parameter grid for the COO sparse-matmul test: n random sizes per dimension,
# each run with B passed directly and with B passed transposed.
n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
#dim3 = 17
# dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
......@@ -1249,8 +1371,10 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
......@@ -1262,18 +1386,17 @@ def test_spmm_coo(dim1, dim2, transposed_B):
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
def test_spmm_bench():
batch = 2
model = 1024*1
hidden = model*4
model = 1024 * 1
hidden = model * 4
seq = 1024
dim1 = batch*seq
dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
A = torch.randn(dim1, dim2, device='cuda').half()
B = torch.randn(dim2, dim3, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B)
......@@ -1282,14 +1405,16 @@ def test_spmm_bench():
for i in range(k):
C1 = bnb.matmul(A, B)
torch.cuda.synchronize()
t8 = time.time()-t0
t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
print(nnz/idx.numel())
print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
for i in range(10):
out2 = F.spmm_coo(cooA, B)
......@@ -1299,20 +1424,22 @@ def test_spmm_bench():
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
tsp = time.time()-t0
tsp = time.time() - t0
print(tsp, t8)
print(tsp/t8)
print(tsp / t8)
# Parameter grid for the integrated sparse-decomposition test: n random sizes
# in [256, 1024) per dimension.
n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
formatB = 'col_turing'
formatB = "col_turing"
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
......@@ -1322,13 +1449,13 @@ def test_integrated_sparse_decomp(dim1, dim2):
CTw1, Sw1 = F.transform(Cw1, formatB)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
......@@ -1338,8 +1465,8 @@ def test_integrated_sparse_decomp(dim1, dim2):
out4 = F.spmm_coo(coo_tensor, w1.t())
out5 = out3 + out4
err1 = torch.abs(out1-out2).mean().item()
err2 = torch.abs(out1-out5).mean().item()
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
......@@ -1350,91 +1477,95 @@ def test_matmuls():
c2 = bnb.matmul(a, b)
c3 = bnb.matmul(a, b)
err1 = torch.abs(c1-c2).mean().item()
err2 = torch.abs(c1-c3).mean().item()
err1 = torch.abs(c1 - c2).mean().item()
err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
# Fixed 2048 x 12288 fp16 case for the very-sparse COO matmul test, run once
# with a zero-initialised and once with a one-initialised output buffer.
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    """Compare ``F.spmm_coo_very_sparse`` with a dense matmul on the outliers.

    ``A`` is thresholded at ``|A| >= 3.3``; the surviving entries are packed
    into a ``F.COOSparseTensor`` and multiplied with ``B`` through the sparse
    kernel, writing into a pre-filled ``out`` buffer. The reference is the
    dense product of the thresholded matrix plus the initial buffer contents.
    Both results are divided by the reference's standard deviation before the
    approximate comparison.

    ``out_func`` is the name of a ``torch`` factory (e.g. ``"zeros"``,
    ``"ones"``) used to create the ``out`` buffer.
    """
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()

    # Both dtype branches build B the same way, so construction is shared.
    B = torch.randn(dim2, dim2 * 4, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # NOTE(review): indentation was lost in this copy of the file; the
    # quantization is assumed to belong to the non-fp16 branch only -- confirm.
    if dtype != torch.float16:
        B, _ = F.vectorwise_quant(B, quant_type="linear")

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    outlier_values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), outlier_values
    )
    A2 = A * idx

    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    # The reference includes the initial buffer contents; presumably the
    # kernel accumulates into ``out`` -- TODO confirm.
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)

    # ``count`` is 200 for a 2048 x (12288 * 4) output and scales with the
    # number of output elements.
    p = 200 / (2048 * 12288 * 4)
    numel = out1.numel()
    count = math.ceil(p * numel)

    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
def test_layout():
    """Print slices of a 16 x 64 byte matrix before and after ``"col_turing"``.

    Diagnostic only: nothing is asserted. Prints the transformed shape, 32
    source elements starting at flat offset ``8 * 64``, and the first 32
    elements of each of the first four ``8 * 32``-element chunks of the
    transformed tensor.
    """
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, _ = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr():
......@@ -1444,14 +1575,16 @@ def test_coo2csr():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
torch.testing.assert_allclose(counts, (A2!=0).sum(1))
idx = (A2!=0)
torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_allclose(A2[idx], csrA.values)
......@@ -1462,41 +1595,43 @@ def test_coo2csc():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_allclose(counts, (A2!=0).sum(0))
torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = (A2.t()!=0)
idx = A2.t() != 0
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
# Single fixed 2048 x 2048 int8 case for the sparse-matmul-with-dequant test.
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
#threshold = 2.8
#threshold = 0.0
A = torch.randn(dim1, dim2, device='cuda').half()
B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16)
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
......@@ -1507,12 +1642,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3*statsBt.half()/127
out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
......@@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
p = 200/(2048*12288*4)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p*n)
count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(100):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
#torch.cuda.synchronize()
#print('fp16', time.time() - t0)
# torch.cuda.synchronize()
# print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo(cooA, B)
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
print('cusparse fp16', time.time() - t0)
print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt)
out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
print('int8', time.time() - t0)
print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
print('int8+dequant', time.time() - t0)
print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = torch.matmul(A, B)
out2 = torch.matmul(A, B)
torch.cuda.synchronize()
print('matmul', time.time() - t0)
print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1+out2
out = out1 + out2
torch.cuda.synchronize()
print('sparse+ matmul', time.time() - t0)
print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
......@@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
print('partial matmul', time.time() - t0)
print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
print('partial matmul', time.time() - t0)
print("partial matmul", time.time() - t0)
# Benchmark shapes as (batch, seq, model, hidden) with hidden = 4 * model.
# Only the first shape is enabled; the rest are kept as commented-out options.
batch_size = 1
seqdim = 2048
values = [
    (batch_size, seqdim, 768, 4 * 768),
    # (batch_size, seqdim, 1024, 4 * 1024),
    # (batch_size, seqdim, 1536, 4 * 1536),
    # (batch_size, seqdim, 2048, 4 * 2048),
    # (batch_size, seqdim, 2560, 4 * 2560),
    # (batch_size, seqdim, 4096, 4 * 4096),
    # (batch_size, seqdim, 5140, 4 * 5140),
    # (batch_size, seqdim, 12288, 4 * 12288),
]
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device='cuda').half()
B = torch.empty(hidden, model, dtype=torch.float16, device='cuda')
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
......@@ -1613,31 +1751,37 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
linearMixedBit = (
bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
)
linearMixedBit.eval()
# warmup
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print('')
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
......@@ -1645,7 +1789,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize()
print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
......@@ -1654,26 +1800,30 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1)
C32A, SA = F.nvidia_transform(CA, 'col32')
C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize()
print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear')
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear')
C32A, SA = F.nvidia_transform(CA, 'col32')
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
out = Cout*statsB*statsA*(1.0/(127*127))
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize()
print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linear8bit(A)
torch.cuda.synchronize()
......@@ -1681,8 +1831,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linear8bit(A)
torch.cuda.synchronize()
print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linearMixedBit(A)
torch.cuda.synchronize()
......@@ -1690,65 +1841,66 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linearMixedBit(A)
torch.cuda.synchronize()
print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
print(
f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
def test_zeropoint():
def min_max(x):
maxA = torch.amax(x, dim=1, keepdim=True)
minA = torch.amin(x, dim=1, keepdim=True)
midpoint = (maxA-minA)/2.0
dyna = 252/(maxA-minA)
#dyna *= 0.98
x = dyna*x
x = x - torch.round((dyna*(minA+midpoint)))
midpoint = (maxA - minA) / 2.0
dyna = 252 / (maxA - minA)
# dyna *= 0.98
x = dyna * x
x = x - torch.round((dyna * (minA + midpoint)))
return x.to(torch.int8), minA, midpoint, dyna
batch = 2
seq = 2
model = 4
hidden = 2*model
#batch = 4
#seq = 2048
#model = 1024
#hidden = 8*model
A = torch.randn(batch*seq, model, device='cuda').half()-0.4
B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half())
#A[0] = 0
#B[:, 0] = 0
#A = A*(A>0)
#A[0, 0] = 0
#A[0, 0] = 6.0
hidden = 2 * model
# batch = 4
# seq = 2048
# model = 1024
# hidden = 8*model
A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())
# A[0] = 0
# B[:, 0] = 0
# A = A*(A>0)
# A[0, 0] = 0
# A[0, 0] = 6.0
Ac, minA, midpoint, dyna = min_max(A)
#print(Ac[0, 0], 'zero')
#print(Ac, Ac.min(), Ac.max())
Bc, maxB = F.vectorwise_quant(B, quant_type='linear')
# print(Ac[0, 0], 'zero')
# print(Ac, Ac.min(), Ac.max())
Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
out = F.igemm(Ac, Bc)
out2 = torch.matmul(A,B)
offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna
out2 = torch.matmul(A, B)
offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
out = out.float()
#print(out.shape, maxB.shape, scale.shape, offset.shape)
norm1 = maxB/127
C4 = (out/dyna)*norm1+offset
# print(out.shape, maxB.shape, scale.shape, offset.shape)
norm1 = maxB / 127
C4 = (out / dyna) * norm1 + offset
B1 = torch.nn.Parameter(B.clone())
B2 = torch.nn.Parameter(B.clone())
B3 = torch.nn.Parameter(B.clone())
B4 = torch.nn.Parameter(B.clone())
C1 = torch.matmul(A, B1)
C2 = bnb.matmul_cublas(A, B2, None, 'linear')
C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint')
C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint')
C2 = bnb.matmul_cublas(A, B2, None, "linear")
C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
err1 = torch.abs(C1-C2).mean().item()
err2 = torch.abs(C1-C3).mean().item()
err3 = torch.abs(C1-C4).mean().item()
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
print(err1, err2, err3)
#assert err1 > err2
# assert err1 > err2
loss1 = C1.mean()
loss2 = C2.mean()
......@@ -1765,40 +1917,38 @@ def test_zeropoint():
print(B2.grad)
print(B3.grad)
print(B4.grad)
err1 = torch.abs(B1.grad-B2.grad).mean().item()
err2 = torch.abs(B1.grad-B3.grad).mean().item()
err3 = torch.abs(B1.grad-B4.grad).mean().item()
err1 = torch.abs(B1.grad - B2.grad).mean().item()
err2 = torch.abs(B1.grad - B3.grad).mean().item()
err3 = torch.abs(B1.grad - B4.grad).mean().item()
print(err1, err2, err3)
def test_zp():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0: dyna = 1
qx = 254./dyna
if dyna == 0:
dyna = 1
qx = 254.0 / dyna
minx = x.min()
#zpx = torch.round(minx* qx)
#zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min()* qx) - 127
x = (qx*x) + zpx
# zpx = torch.round(minx* qx)
# zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min() * qx) - 127
x = (qx * x) + zpx
return x, qx, zpx
batch = 2
seq = 512
model = 1024
hidden = 4*model
A = torch.randn(batch*seq, model, device='cuda').half()*0.1
B = torch.randn(model, hidden, device='cuda').half()*0.1
hidden = 4 * model
A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
#A, SA = F.vectorwise_quant(A, quant_type='linear')
#B, SB = F.vectorwise_quant(B, quant_type='linear')
# A, SA = F.vectorwise_quant(A, quant_type='linear')
# B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
......@@ -1806,69 +1956,68 @@ def test_zp():
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
#C2 = torch.matmul(A-zp, B)
#C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B-zp)
C2 -= A.sum(1).view(-1, 1)*zp
# C2 = torch.matmul(A-zp, B)
# C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max())
print((ca-cza).min(), (ca-cza).max())
print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
C5 = torch.matmul((A*scale)-zp, B)
C5 += B.sum(0)*zp
C5 = torch.matmul((A * scale) - zp, B)
C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
C4 -= B.sum(0)*zpa
C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb)
C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
C6 -= zpa*zpb*A.shape[1]
C6 /= qa*qb
C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C6 -= zpa * zpb * A.shape[1]
C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
C7 -= zpa*zpb*A.shape[1]
C7 /= qa*qb
C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb
print('')
#print(C0.flatten()[:10])
print("")
# print(C0.flatten()[:10])
print(C1.flatten()[:10])
print(C2.flatten()[:10])
print(C3.flatten()[:10])
print(C5.flatten()[:10])
print(C6.flatten()[:10])
print(C7.flatten()[:10])
err1 = torch.abs(C1-C2).mean().item()
err2 = torch.abs(C1-C3).mean().item()
err3 = torch.abs(C1-C4).mean().item()
err4 = torch.abs(C1-C5).mean().item()
err5 = torch.abs(C1-C6).mean().item()
err6 = torch.abs(C1-C7).mean().item()
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
def test_extract_outliers():
for i in range(k):
shapeA = (4096, 4096*4)
shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
#idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
# idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()]
CA, SA = F.transform(A, 'col_turing')
CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx)
......@@ -1877,7 +2026,7 @@ def test_extract_outliers():
torch.testing.assert_allclose(outliers1, outliers2)
CA, SA = F.transform(A, 'col_ampere')
CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx)
......
from itertools import product
import pytest
import torch
from itertools import product
from torch import nn
import bitsandbytes as bnb
class MockArgs:
    """Plain attribute container used as a stand-in for parsed arguments.

    Every key yielded by iterating ``initial_data`` becomes an attribute whose
    value is ``initial_data[key]``. An empty container creates no attributes.
    """

    def __init__(self, initial_data):
        for attr_name in initial_data:
            attr_value = initial_data[attr_name]
            setattr(self, attr_name, attr_value)
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
)
def forward(self, x):
x = self.fc1(x)
......@@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
def get_args():
    """Build the default argument object used by the 8-bit linear tests.

    Returns a ``MockArgs`` with ``quant_type="vector"``,
    ``use_8bit_training="full"`` and ``clip_freq=9999``.
    """
    defaults = (
        ("quant_type", "vector"),
        ("use_8bit_training", "full"),
        ("clip_freq", 9999),
    )
    args = MockArgs([])
    for attr_name, attr_value in defaults:
        setattr(args, attr_name, attr_value)
    return args
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    """Assert that ``a`` and ``b`` are close, tolerating a few outliers.

    Elements are compared with ``torch.isclose``. Up to ``count`` elements may
    fall outside the tolerances without failing; once more than ``count``
    mismatch, a message is printed and ``torch.testing.assert_allclose`` is
    invoked on the full tensors with the same tolerances.
    """
    close_mask = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (~close_mask).sum().item()
    if sumval <= count:
        # Few enough mismatches: accept.
        return
    print(f"Too many values not close: assert {sumval} < {count}")
    torch.testing.assert_allclose(a, b, rtol, atol)
class LinearFunction(torch.autograd.Function):
class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
    """Simulate 8-bit linear quantization with clipping of large values.

    The input is scaled so that ``trim_value`` standard deviations map to 127,
    rounded, clipped to ``[-127, 127]`` and scaled back to the input range.

    Args:
        x: Tensor to fake-quantize. It is not modified.
        stochastic: Round with ``LinearFunction.round_stoachastic`` instead of
            ``torch.round``.
        trim_value: Number of standard deviations that maps to 127.

    Returns:
        A new tensor holding the quantize-dequantize round trip of ``x``.
    """
    round_func = LinearFunction.round_stoachastic if stochastic else torch.round
    # An alternative spread estimate that was tried earlier is
    # mean(|x|) * sqrt(pi / 2). It is unused, so the constant is no longer
    # computed; computing it needed the `math` module, which this file does
    # not import.
    std = torch.std(x)
    max1 = std * trim_value
    scaled = round_func(x / max1 * 127)
    # Everything beyond trim_value standard deviations saturates.
    clipped = torch.clamp(scaled, -127, 127)
    return clipped / 127 * max1
def quant(x, quant_type, dim=1):
    """Quantize ``x`` to int8 and return it with its scaling statistics.

    Args:
        x: Tensor to quantize.
        quant_type: One of
            ``"linear"``  - a single absolute maximum over the whole tensor;
            ``"vector"``  - absolute maxima along ``dim`` (kept as size-1 dims);
            ``"min-max"`` - per-``dim`` minimum and half range.
        dim: Dimension(s) reduced for the "vector" and "min-max" modes.

    Returns:
        ``(xq, max1)`` for "linear" and "vector", ``(xq, (minA, scale))`` for
        "min-max", or ``None`` for an unrecognised ``quant_type``.

    A scale of zero (all-zero input, or a constant slice in "min-max" mode) is
    replaced by 1 so the division cannot produce NaN before the int8 cast.
    """

    def _nonzero(scale):
        # Swap exact zeros for ones; every other entry is returned unchanged.
        return torch.where(scale == 0, torch.ones_like(scale), scale)

    if quant_type == "linear":
        max1 = _nonzero(torch.abs(x).max().float())
        xq = torch.round(x / max1 * 127).to(torch.int8)
        return xq, max1
    if quant_type == "vector":
        max1 = _nonzero(torch.amax(torch.abs(x), dim=dim, keepdim=True))
        xq = torch.round(x / max1 * 127).to(torch.int8)
        return xq, max1
    if quant_type == "min-max":
        maxA = torch.amax(x, dim=dim, keepdim=True).float()
        minA = torch.amin(x, dim=dim, keepdim=True).float()
        scale = _nonzero((maxA - minA) / 2.0)
        # Centre each slice on its midpoint, then map the half range to 127.
        xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
        return xq, (minA.float(), scale.float())
    return None
def dequant(xq, S1, S2, dtype, quant_type):
    """Rescale an integer matmul result back to ``dtype``.

    Args:
        xq: Quantized product to dequantize.
        S1, S2: Scaling statistics of the two operands.
        dtype: Output dtype.
        quant_type: ``"linear"`` or ``"vector"``; anything else yields ``None``.

    For "vector", a leading size-1 dimension is squeezed from a 3-D ``S1`` or
    ``S2`` when ``xq`` is 2-D, and a 2-D ``S1`` is transposed before scaling.
    """
    if quant_type == "linear":
        scale = S1 * S2 / (127 * 127)
        # double cast needed to prevent overflows
        return (xq.float() * scale).to(dtype)
    if quant_type != "vector":
        return None

    out = xq.float()
    if len(xq.shape) == 2:
        if len(S1.shape) == 3:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3:
            S2 = S2.squeeze(0)
    first_scale = S1.t() if len(S1.shape) == 2 else S1
    out *= first_scale / 127
    out *= S2 / 127
    return out.to(dtype)
def dequant_min_max(xq, A, B, SA, SB, dtype):
    """Dequantize a product whose first operand was min-max quantized.

    Args:
        xq: Quantized product to dequantize.
        A: Unused by this function.
        B: Second operand; its transposed sum over dim 0 forms the offset.
        SA: Indexable pair of statistics for the first operand; ``SA[0]`` and
            ``SA[1]`` are presumably the minimum and half range returned by
            ``quant(..., "min-max")`` - TODO confirm against callers.
        SB: Scaling statistics of the second operand.
        dtype: Output dtype.
    """
    # Additive term restored after scaling: column sums of B times (SA[0] + SA[1]).
    shift = B.float().t().sum(0) * (SA[0] + SA[1])
    out = xq.float()
    if len(xq.shape) == 2:
        if len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(SA.shape) == 3:
            SA = SA.squeeze(0)
    second_scale = SB.t() if len(SB.shape) == 2 else SB
    out *= second_scale / 127
    out *= SA[1] / 127
    out += shift
    return out.to(dtype)
def get_8bit_linear(x, stochastic=False):
    """Simulate 8-bit linear quantization using one scale for the whole tensor.

    ``x`` is scaled so its absolute maximum maps to 127, rounded (with
    ``LinearFunction.round_stoachastic`` when ``stochastic`` is true, otherwise
    ``torch.round``) and scaled back to the original range.
    """
    if stochastic:
        round_func = LinearFunction.round_stoachastic
    else:
        round_func = torch.round
    max1 = torch.abs(x).max()
    scaled = x / max1 * 127
    # An earlier variant divided by 128 here instead of 127.
    return round_func(scaled) / 127 * max1
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
    """Simulate 8-bit quantization with one scale per slice along ``dim``.

    The absolute maximum along ``dim`` is mapped to 127; slices whose maximum
    is zero use a scale of 1 instead. Values are rounded (stochastically when
    ``stochastic`` is true) and scaled back to the original range.
    """
    if stochastic:
        round_func = LinearFunction.round_stoachastic
    else:
        round_func = torch.round
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    # All-zero slices would otherwise divide by zero.
    max1.masked_fill_(max1 == 0, 1.0)
    scaled = (x * 127) / max1
    return round_func(scaled) / 127 * max1
@staticmethod
def round_stoachastic(x):
    """Round each element of ``x`` up or down in magnitude at random.

    The magnitude is rounded up when a uniform random draw is below its
    fractional part, otherwise down; the sign of ``x`` is then restored.
    Values with no fractional part are returned unchanged.
    """
    magnitude = torch.abs(x)
    whole = torch.floor(magnitude)
    fraction = magnitude - whole
    draws = torch.rand_like(fraction)
    bump = (draws < fraction).to(x.dtype)
    return torch.sign(x) * (whole + bump)
@staticmethod
def fake_8bit_storage(w, exponent_bits):
......@@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
#C = bnb.functional.quantize_no_absmax(code, w)
#out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
#print(out)
#out = out.half()
# C = bnb.functional.quantize_no_absmax(code, w)
# out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
# print(out)
# out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
......@@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
......@@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
w.copy_(unblocked_w)
return unblocked_w
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != 'off':
if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
#if torch.rand(1) < 0.01:
#output32 = torch.matmul(x, weight.t())
#err = torch.abs(output-output32).float()
#relerr = err/(torch.abs(output32).float()+1e-8)
#print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
# relerr = err/(torch.abs(output32).float()+1e-8)
# print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
#output = torch.matmul(x, weight.t())
output = torch.einsum('bsi,oi->bso', x, weight)
# output = torch.matmul(x, weight.t())
output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
......@@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
if args.use_8bit_training == 'forward+wgrad':
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
#grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == 'full':
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
grad_input = LinearFunction.dequant(
grad_input8, S1, S3, grad_output.dtype, args.quant_type
)
else:
grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__()
......@@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
......@@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args)
def test_linear8bit():
l0 = torch.nn.Linear(32, 64).cuda().half()
l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
l0.weight.data = l2.weight.data.clone()
l0.bias.data = l2.bias.data.clone()
......@@ -292,8 +314,8 @@ def test_linear8bit():
l3.bias.data = l2.bias.data.clone()
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
t = torch.randn(16, 8, 64, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
t = torch.randn(16, 8, 64, device="cuda").half()
b2 = b1.clone()
b3 = b1.clone()
b0 = b1.clone()
......@@ -318,16 +340,20 @@ def test_linear8bit():
assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
assert_all_approx_close(
l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
assert_all_approx_close(
l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
assert err1*0.8 < err2
assert err2*0.8 < err3
assert err3*0.8 < err1
assert err1 * 0.8 < err2
assert err2 * 0.8 < err3
assert err3 * 0.8 < err1
l0.weight.grad = None
l1.weight.grad = None
......@@ -341,23 +367,28 @@ def test_linear8bit():
threshold = [0.0, 3.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    """Run ``Linear8bitLt`` in eval mode and check its weight and state.

    The layer's weight must be a float16 CUDA tensor after ``.cuda().half()``,
    and ``state.CxB`` must be set once the second forward pass has run.
    """
    layer = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    weight = layer.weight
    assert weight.device.type == "cuda"
    assert weight.dtype == torch.float16

    layer.eval()
    for step in range(100):
        batch = torch.randn(16, 8, 32, device="cuda").half()
        layer(batch)
        if step != 1:
            continue
        assert layer.state.CxB is not None
def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
l1 = torch.nn.Sequential(
*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
)
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
......@@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10
for i in range(10):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
o2 = l2(b1)
loss1 = o1.mean()
......@@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True)
opt2.step()
opt2.zero_grad(True)
assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
assert_all_approx_close(
l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
)
assert_all_approx_close(
l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
)
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
......@@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
l1 = (
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
assert o1.dtype == torch.float16
......@@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda'
assert mlp.fc2.weight.device.type == 'cuda'
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.to(torch.float16)
.to("cuda")
)
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
if threshold > 0:
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda'
assert mlp.fc2.weight.device.type == 'cuda'
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
import ctypes
import os
import time
import shutil
import time
import uuid
from itertools import product
from os.path import join
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from os.path import join
from itertools import product
#import apex
# import apex
k = 20
def get_temp_dir():
    """Create a uniquely named directory under ``/tmp/autoswap`` and return its path.

    The directory name is a freshly generated UUID4; missing parent
    directories are created as needed.
    """
    unique_name = str(uuid.uuid4())
    path = "/tmp/autoswap/{0}".format(unique_name)
    os.makedirs(path, exist_ok=True)
    return path
def rm_path(path):
    """Recursively delete the directory tree rooted at ``path``.

    Errors from ``shutil.rmtree`` (for example a missing path) propagate.
    """
    shutil.rmtree(path, ignore_errors=False)
# Optimizer factories keyed by test name. In the two-element entries the first
# item builds the reference optimizer and the second builds the bitsandbytes
# optimizer under test; each is called with a list of parameters.
# Entries that depend on apex are kept as comments because apex is not imported.
str2optimizers = {
    "adam_pytorch": (None, torch.optim.Adam, bnb.optim.Adam),
    # 'adam_apex': (None, apex.optimizers.FusedAdam, bnb.optim.Adam),
    # 'momentum_apex': (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam),
    "momentum_pytorch": (
        None,
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        bnb.optim.Adam,
    ),
    # 'lamb_apex': (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam),
    # 'lars_apex': (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam),
    "adam": (torch.optim.Adam, bnb.optim.Adam),
    # 'fused_adam': (apex.optimizers.FusedAdam, bnb.optim.Adam),
    "momentum": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD(params, 0.01, 0.9, block_wise=False),
    ),
    "lars": (
        lambda params: bnb.optim.PytorchLARS(params, 0.01, 0.9),
        lambda params: bnb.optim.LARS(params, 0.01, 0.9),
    ),
    # 'lamb': (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB),
    "rmsprop": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop(params, 0.01, 0.9, block_wise=False),
    ),
    "adam8bit": (
        torch.optim.Adam,
        lambda params: bnb.optim.Adam8bit(params, block_wise=False),
    ),
    "momentum8bit": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD8bit(params, 0.01, 0.9, block_wise=False),
    ),
    "rmsprop8bit": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop8bit(params, 0.01, 0.9, block_wise=False),
    ),
    # 'lamb8bit': (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit),
    "lars8bit": (
        lambda params: bnb.optim.PytorchLARS(params, 0.01, 0.9),
        lambda params: bnb.optim.LARS8bit(params, 0.01, 0.9),
    ),
    "adam8bit_blockwise": (
        torch.optim.Adam,
        lambda params: bnb.optim.Adam8bit(params, block_wise=True),
    ),
    "momentum8bit_blockwise": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD8bit(params, 0.01, 0.9, block_wise=True),
    ),
    "rmsprop8bit_blockwise": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop8bit(params, 0.01, 0.9, block_wise=True),
    ),
}
# State-dict key names per optimizer. Each tuple starts with the reference
# optimizer's state key followed by the bitsandbytes state key; the 8-bit
# entries carry two further key names (a "qmap" key and a "max"/"absmax" key).
str2statenames = {
    "adam": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "momentum": [("momentum_buffer", "state1")],
    "lars": [("momentum_buffer", "state1")],
    "lamb": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "rmsprop": [("square_avg", "state1")],
    "adam8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "lamb8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "adam8bit_blockwise": [
        ("exp_avg", "state1", "qmap1", "absmax1"),
        ("exp_avg_sq", "state2", "qmap2", "absmax2"),
    ],
    "momentum8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "momentum8bit_blockwise": [("momentum_buffer", "state1", "qmap1", "absmax1")],
    "lars8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "rmsprop8bit": [("square_avg", "state1", "qmap1", "max1")],
    "rmsprop8bit_blockwise": [("square_avg", "state1", "qmap1", "absmax1")],
}
# Parameter grid for test_optimizer32bit: every combination of tensor shape,
# gradient dtype and optimizer name, plus a readable pytest id per combination.
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
# "lamb" is left out on purpose: its str2optimizers entry is commented out
# (the reference implementation comes from apex), so parametrizing over it
# would fail with KeyError before the test body does any checking.
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
......@@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
else:
atol, rtol = 1e-4, 1e-3
for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
......@@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
if i % (k//5) == 0 and i > 0:
if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit
......@@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.half().float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2)
if optim_name in ['lars', 'lamb']:
assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1,dim2, gtype))
names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
mask = torch.rand_like(p2) < 0.1
beta1 = 0.9
beta2 = 0.999
......@@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
p1 = p1.cuda()
......@@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype):
atol, rtol = 1e-4, 1e-3
for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
p1.grad = g1
p2.grad = g2
p3.grad = g3
adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8
assert adam2.state[p3]["state1"].dtype == torch.uint8
assert adam2.state[p3]["state2"].dtype == torch.uint8
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
optimizer_names = [
"adam8bit",
"momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lamb8bit",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
......@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors = []
for i in range(50):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
......@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
#print(bnb_optimizer.state[p2][max_val], name1)
if 'blockwise' in optim_name:
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
# print(bnb_optimizer.state[p2][max_val], name1)
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
err = torch.abs(p1-p2)
relerr = err/torch.abs(p1)
err = torch.abs(p1 - p2)
relerr = err / torch.abs(p1)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
......@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
if 'blockwise' in optim_name:
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
torch.testing.assert_allclose(s1cpy, s1)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
......@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
torch_optimizer.state[p1][name1].copy_(s.data)
#print(sum(errors)/len(errors))
#print(sum(relerrors)/len(relerrors))
# print(sum(errors)/len(errors))
# print(sum(relerrors)/len(relerrors))
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1,dim2, gtype, optim_bits))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
values = list(product(dim1, dim2, gtype, optim_bits))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
......@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p1 = p1.cuda()
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
adam2 = bnb.optim.Adam(
[p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
)
gnorm_vec = torch.zeros(100).cuda()
step = 0
for i in range(50):
step += 1
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
g2 = g1.clone()
p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
g1 = (g1.float()*gnorm_scale).to(gtype)
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
g1, gnorm_vec, step, 5
)
g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1
adam1.step()
......@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
torch.testing.assert_allclose(
adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=5e-5,
rtol=1e-4,
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=5e-5,
rtol=1e-4,
)
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])
torch.testing.assert_allclose(
adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
)
adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
if i % 10 == 0 and i > 0:
path = get_temp_dir()
torch.save(adam2.state_dict(),join(path, 'opt.pt'))
torch.save(adam2.state_dict(), join(path, "opt.pt"))
del adam2
adam2 = None
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
adam2 = bnb.optim.Adam(
[p2],
lr,
(beta1, beta2),
eps,
optim_bits=optim_bits,
percentile_clipping=5,
)
adam2.load_state_dict(torch.load(join(path, "opt.pt")))
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
#optimizer_names = ['lamb_apex', 'lamb8bit']
#optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ['adam8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
if i == k//5:
if i == k // 5:
# the first k // 5 iterations are burn-in; timing starts here
torch.cuda.synchronize()
t0 = time.time()
......@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch.cuda.synchronize()
s = time.time()-t0
print('')
params = (k-k//5)*dim1*dim2
print(optim_name, gtype, s/params)
#assert s < 3.9
s = time.time() - t0
print("")
params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s / params)
# assert s < 3.9
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment