Commit bfa0e332 authored by Titus von Koeller

ran black and isort for coherent code formatting

parent 597a8521
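
The exact flags and configuration used for this pass are not recorded in the commit, but a run along these lines (an assumed invocation, not quoted from the repository) would typically produce the isort/black style seen in the hunks below:

    isort .    # sort and group imports
    black .    # reformat code to black's default style
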
from itertools import product

import pytest
import torch

import bitsandbytes as bnb

n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(
    product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
@@ -32,9 +43,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
@@ -52,9 +65,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
@@ -78,16 +91,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.02
                    torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
            )
            B = torch.randn(
                size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
            )
            target = torch.randn(
                size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
@@ -95,7 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)

            if any(req_grad):
@@ -120,16 +139,20 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.02

        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
@@ -141,9 +164,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
@@ -167,51 +190,96 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.02
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
# dim1 = (17,)
# dim2 = (7,)
# dim3 = (37,)
# dim4 = (23,)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
    values,
    ids=names,
)
def test_matmullt(
    dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()
@@ -219,8 +287,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
@@ -231,12 +306,12 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if has_fp16_weights:
                if any(req_grad):
@@ -263,8 +338,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                        idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                        assert (idx == 0).sum().item() < n * 0.1
                        idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                        assert (idx == 0).sum().item() < n * 0.02
                        torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)

import os
from typing import List, NamedTuple

import pytest

from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
                                     get_cuda_runtime_lib_path, tokenize_paths)


class InputAndExpectedOutput(NamedTuple):
    input: str
    output: str


HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
]


@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
    for path in tokenize_paths(request.param):
        test_dir.mkdir()
        if CUDA_RUNTIME_LIB in path:
            (test_input / CUDA_RUNTIME_LIB).touch()


@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
    tmp_path, test_input: str, expected: str
):
    for path in tokenize_paths(test_input):
        path.mkdir()
        (path / CUDA_RUNTIME_LIB).touch()
    assert get_cuda_runtime_lib_path(test_input) == expected
@@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
    (test_input / CUDA_RUNTIME_LIB).touch()
    with pytest.raises(FileNotFoundError) as err_info:
        get_cuda_runtime_lib_path(test_input)
    assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB})


def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    existent_dir = tmp_path / "a/b"
    existent_dir.mkdir()
    non_existent_dir = tmp_path / "c/d"  # non-existent dir
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    get_cuda_runtime_lib_path(test_input)
    std_err = capsys.readouterr().err
    assert all(match in std_err for match in {"WARNING", "non-existent"})


def test_full_system():
    ## this only tests the cuda version and not compute capability
    ld_path = os.environ["LD_LIBRARY_PATH"]
    paths = ld_path.split(":")
    version = ""
    for p in paths:
        if "cuda" in p:
            idx = p.rfind("cuda-")
            version = p[idx + 5 : idx + 5 + 4].replace("/", "")
            version = float(version)
            break

    binary_name = evaluate_cuda_setup()
    binary_name = binary_name.replace("libbitsandbytes_cuda", "")
    assert binary_name.startswith(str(version).replace(".", ""))

import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol, atol, count):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)
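
# Illustrative usage of the helper above (not part of the original file): it
# tolerates up to `count` mismatched elements before falling back to the strict
# torch.testing comparison, e.g.
#
#   a = torch.randn(1024, device="cuda")
#   b = a + 1e-3 * torch.randn(1024, device="cuda")
#   assert_all_approx_close(a, b, rtol=0.1, atol=0.01, count=10)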

class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
@@ -35,13 +39,14 @@ class FFN(torch.nn.Module):
        x = self.fc2(x)
        return x


class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
@@ -49,66 +54,70 @@ class Timer(object):
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms

        if evict:
            self.starts.pop(name)
            self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype): def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device='cuda') A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype) A = A.to(dtype)
code = F.estimate_quantiles(A) code = F.estimate_quantiles(A)
percs = torch.linspace(1/512, 511/512, 256, device=A.device) percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device='cuda') A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype) A = A.to(dtype)
code = F.estimate_quantiles(A) code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs) quantiles = torch.quantile(A.float(), percs)
diff = torch.abs(code-quantiles) diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0 assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization(): def test_quantile_quantization():
for i in range(100): for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda') A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1) code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code) C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code) A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1-A2).mean().item() diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075 assert diff < 0.0075
A1 = torch.rand(1024, 1024, device='cuda') A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1) code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code) C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code) A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1-A2).mean().item() diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001 assert diff < 0.001
@@ -117,22 +126,22 @@ def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004

@@ -141,56 +150,60 @@ def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))

def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximunm distance of quantized values of 1
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
        gnorm2 = torch.norm(g.float())
        if step == 1:
@@ -208,74 +221,89 @@ def test_percentile_clipping(gtype):

def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)
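
# Illustrative round trip with the helpers above (not part of the original file):
# `quant` collapses a tensor to int8 with a single absmax scale and `dequant`
# undoes it, so the reconstruction error stays within roughly max1 / 254, e.g.
#
#   x = torch.randn(64, 64, device="cuda")
#   max1, x8 = quant(x)
#   x_restored = dequant(x8, max1)
#   print((x - x_restored).abs().max())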

def quant_minmax(A):
    minA = A.min()
    maxA = A.max()


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names
]
@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) @pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched): def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32) dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32) dim2 = dim2 - (dim2 % 32)
errors = [] errors = []
relerrors = [] relerrors = []
print('') print("")
for i in range(5): for i in range(5):
if batched: if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda') A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda') B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2) maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1) maxB, Bc = quant_methods[1](B, 1)
else: else:
A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda') A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda') B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1) maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0) maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) torch.testing.assert_allclose(
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
)
if batched: if batched:
out2 = torch.bmm(A, B) out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float()) C = torch.bmm(Ac.float(), Bc.float())
...@@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): ...@@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
C = F.igemm(Ac, Bc) C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C) out = quant_methods[4](maxA, maxB, C)
std = out2.std() std = out2.std()
out/= std out /= std
out2/= std out2 /= std
err = torch.abs(out-out2) err = torch.abs(out - out2)
relerr = err/torch.abs(out2) relerr = err / torch.abs(out2)
errors.append(err.mean().item()) errors.append(err.mean().item())
relerrors.append(relerr.mean().item()) relerrors.append(relerr.mean().item())
print(mean(errors)) print(mean(errors))
print(mean(relerrors)) print(mean(relerrors))

def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]
@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) @pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16) batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16) seq_dim = seq_dim - (seq_dim % 16)
for i in range(k): for i in range(k):
shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) shapeA = (
shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) )
B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]: if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float()) out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B) out = F.igemm(A, B)
...@@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k): for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim) shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) shapeB = (
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) (32 * random.randint(1, 4), hidden_dim)
B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]: if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float()) out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B) out = F.igemm(A, B)
...@@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):

n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(
            torch.int8
        )
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        out = F.igemm(A, B, out=iout)
        torch.testing.assert_allclose(out.float(), out2)

n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
@@ -395,30 +444,30 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset
            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset
            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
@@ -429,31 +478,36 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
        out /= std
        out3 /= std
        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)
        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)
        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3

n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
@@ -462,8 +516,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
@@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())

n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)

n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and out_order != "col32":
        return
    if dtype == torch.int32 and out_order != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col
                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32
                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_allclose(A, out2)
n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
                torch.int8
            )
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())
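
# --- Editor's illustrative sketch (not part of the original suite) ---
# What the igemmlt round trip above verifies, stripped of the layout
# transforms: an int8 x int8 matmul accumulated in int32 matches the float
# reference exactly, because sums of int8 products over these inner
# dimensions stay well inside float32's exact-integer range. Pure-PyTorch
# reference only; the helper name is ours.
def _sketch_int8_matmul_reference():
    A = torch.randint(-128, 127, size=(4, 32), dtype=torch.int8)
    B = torch.randint(-128, 127, size=(8, 32), dtype=torch.int8)
    C_int = torch.matmul(A.to(torch.int32), B.t().to(torch.int32))
    C_float = torch.matmul(A.float(), B.t().float())
    assert torch.equal(C_int, C_float.to(torch.int32))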
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

@@ -627,50 +709,56 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())
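
# --- Editor's illustrative sketch (not part of the original suite) ---
# A simplified, pure-PyTorch version of the pipeline exercised above
# (quantize row-wise to int8, multiply, undo both scales). It mirrors what
# double_quant -> igemmlt -> mm_dequant is being compared against, not the
# actual fused CUDA implementation; the loose bound covers quantization error.
def _sketch_int8_pipeline_reference():
    A = torch.randn(16, 64)
    B = torch.randn(32, 64)

    def rowwise_quant(X):
        absmax = X.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
        return torch.round(X / absmax * 127), absmax.squeeze(1)

    CA, statsA = rowwise_quant(A)  # per-row scales of A
    CB, statsB = rowwise_quant(B)  # per-row scales of B (= columns of B.t())
    C32 = CA @ CB.t()  # the "int32" accumulation, held in float here
    out = C32 * statsA[:, None] * statsB[None, :] / (127 * 127)
    ref = A @ B.t()
    assert (out - ref).abs().mean().item() < 0.15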
batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]

# values = list(product(batch, seq, model, hidden))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()

@@ -679,77 +767,77 @@ def test_bench_8bit_training(batch, seq, model, hidden):

    t0 = time.time()
    for i in range(k):
        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')

    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)

@@ -802,74 +890,76 @@ def test_bench_8bit_training(batch, seq, model, hidden):

    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)
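
# --- Editor's illustrative sketch (not part of the original suite) ---
# The fp16 baseline above follows the usual CUDA timing recipe: warm up,
# synchronize, time the loop, synchronize again before reading the clock.
# The helper name is ours; the benchmarks in this file inline this logic.
def _sketch_bench_cuda(fn, warmup=10, iters=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters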
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

# dim1 = [2*1024]
# dim4 = [2*1024]

# dim1 = [4]
# dim4 = [4]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
    inner = torch.randint(1, 128, size=(1,)).item()
    formatB = F.get_special_format_str()
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        assert (
            count / n < p
        ), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
        torch.testing.assert_allclose(C5, C4)
        # print(C2)
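
# --- Editor's illustrative sketch (not part of the original suite) ---
# Int8 matmul cannot meet a strict elementwise tolerance everywhere, so the
# test above bounds the *fraction* of out-of-tolerance elements instead of
# every element. The helper restates that check; the name is ours.
def _sketch_fraction_mismatched(ref, approx, atol=0.01, rtol=0.1):
    bad = ~torch.isclose(ref, approx, atol=atol, rtol=rtol)
    return bad.float().mean().item()  # compare against p, e.g. 0.06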
n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:

@@ -880,11 +970,22 @@ def test_colrow_absmax(dim1, dim2, dims):

        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)

@@ -898,19 +999,20 @@ def test_colrow_absmax(dim1, dim2, dims):

        assert nnz_block_ptr2 is None
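
# --- Editor's illustrative sketch (not part of the original suite) ---
# The property checked above, in plain torch: row-wise and column-wise absmax
# statistics are computed with outliers (|a| >= threshold) zeroed out. This
# is only a reference for the semantics being tested, not the fused CUDA
# kernel, and the helper name is ours.
def _sketch_colrow_absmax_reference(A, threshold=0.0):
    absA = A.abs().float()
    if threshold > 0.0:
        absA = torch.where(absA >= threshold, torch.zeros_like(absA), absA)
    row_stats = absA.max(dim=1).values
    col_stats = absA.max(dim=0).values
    return row_stats, col_stats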
n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

@@ -920,18 +1022,21 @@ def test_double_quant(dim1, dim2):

        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
@@ -939,21 +1044,23 @@ def test_double_quant(dim1, dim2):

n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

dim1 = [6]
dim4 = [4]
inner = [8]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

@@ -967,30 +1074,32 @@ def test_integrated_igemmlt(dim1, dim4, inner):

        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.01
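
# --- Editor's illustrative sketch (not part of the original suite) ---
# test_double_quant above compares double_quant against vectorwise_quant
# along dim=1 and dim=0; the idea in plain torch is that the same matrix is
# quantized twice, once with per-row absmax scales and once with per-column
# scales. Simplified reference only -- rounding details of the CUDA kernel may
# differ, hence the 1:500 mismatch allowance in that test.
def _sketch_double_quant_reference(A):
    row_absmax = A.abs().float().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    col_absmax = A.abs().float().max(dim=0, keepdim=True).values.clamp(min=1e-8)
    CA = torch.round(A.float() / row_absmax * 127).to(torch.int8)  # row scales
    CAt = torch.round(A.float() / col_absmax * 127).to(torch.int8)  # column scales
    return CA, CAt, row_absmax.squeeze(1), col_absmax.squeeze(0)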
n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):

@@ -999,79 +1108,79 @@ def test_igemmlt_row_scale(dim1, dim4, inner):

    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))
dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):

@@ -1082,23 +1191,22 @@ def test_row_scale_bench(dim1, dim4, inner):

    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)

@@ -1107,32 +1215,39 @@ def test_row_scale_bench(dim1, dim4, inner):

    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)
n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        A.view(-1)[-1] = -1
        if transpose:

@@ -1144,53 +1259,55 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)
n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)

        out2, S2 = F.transform(A, to_order=orderA)
        A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
        assert A2.shape[0] == A.shape[0]
        assert A2.shape[1] == A.shape[1]

        print("")
        print(A)
        print(out2)
        print(A2)

        # torch.testing.assert_allclose(A, A2)


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)

@@ -1198,46 +1315,51 @@ def test_overflow():
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
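
# --- Editor's illustrative sketch (not part of the original suite) ---
# The decomposition verified above, in plain torch: entries with
# |a| >= threshold are kept in fp16 COO form, the remainder is quantized
# row-wise to int8, and dequantizing that remainder recovers the non-outlier
# part to within the tolerances used above. Reference only; not the
# bitsandbytes kernels, and the helper name is ours.
def _sketch_split_outliers(A, threshold=3.0):
    idx = A.abs() >= threshold
    rows, cols = torch.where(idx)
    coo = (rows.int(), cols.int(), A[idx])  # outliers stay in fp16
    dense = A.masked_fill(idx, 0.0)
    absmax = dense.abs().float().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    CA = torch.round(dense.float() / absmax * 127).to(torch.int8)
    return CA, absmax.squeeze(1), coo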
n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:

@@ -1249,8 +1371,10 @@ def test_spmm_coo(dim1, dim2, transposed_B):

        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx
        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())

@@ -1262,18 +1386,17 @@ def test_spmm_coo(dim1, dim2, transposed_B):

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
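
# --- Editor's illustrative sketch (not part of the original suite) ---
# A torch.sparse reference for what spmm_coo is checked against: build a
# sparse matrix from the same COO data (rows, cols, values) and multiply it
# with a dense matrix to get a dense result. Helper name is ours.
def _sketch_spmm_coo_reference(rows, cols, vals, shape, B):
    indices = torch.stack([rows.long(), cols.long()])
    A_sparse = torch.sparse_coo_tensor(indices, vals.float(), size=shape)
    return torch.sparse.mm(A_sparse, B.float())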
def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B)

@@ -1282,14 +1405,16 @@ def test_spmm_bench():

    for i in range(k):
        C1 = bnb.matmul(A, B)
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

@@ -1299,20 +1424,22 @@ def test_spmm_bench():

    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)
n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()

@@ -1322,13 +1449,13 @@ def test_integrated_sparse_decomp(dim1, dim2):

        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

@@ -1338,8 +1465,8 @@ def test_integrated_sparse_decomp(dim1, dim2):

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1
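
# --- Editor's illustrative sketch (not part of the original suite) ---
# The property asserted above, written out: with outliers handled separately,
#   A @ W.t()  ~=  dequant(int8(A_clamped) @ int8(W))  +  sparse(A_outliers) @ W.t()
# and the combined error should be smaller than the int8-only error. The
# reference below is pure torch and only illustrative of that decomposition.
def _sketch_mixed_decomposition_matmul(A, W, threshold=3.0):
    idx = A.abs() >= threshold
    dense = A.masked_fill(idx, 0.0).float()
    absmax_a = dense.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    absmax_w = W.abs().float().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    qa = torch.round(dense / absmax_a * 127)
    qw = torch.round(W.float() / absmax_w * 127)
    int8_part = (qa @ qw.t()) * absmax_a * absmax_w.t() / (127 * 127)
    sparse_part = (A * idx).float() @ W.float().t()
    return int8_part + sparse_part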
@@ -1350,91 +1477,95 @@ def test_matmuls():

    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul(a, b)
    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #    #out3 = F.spmm_coo(cooA, Bt.t())
    #    #out2 = F.spmm_coo(cooA, B)
    #    #out2 = F.spmm_coo_very_sparse(cooA, B)
    #    #out1 = torch.matmul(A, Bt.t())
    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_layout():
    a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, s2 = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr():

@@ -1444,14 +1575,16 @@ def test_coo2csr():

    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)
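
# --- Editor's illustrative sketch (not part of the original suite) ---
# COO -> CSR in plain torch, matching the rowptr property checked above:
# rowptr[i + 1] - rowptr[i] equals the number of nonzeros in row i. Assumes
# the COO entries are already in row-major order, as torch.where produces
# them; helper name is ours.
def _sketch_coo_to_csr(rows, cols, vals, num_rows):
    counts = torch.bincount(rows.long(), minlength=num_rows)
    rowptr = torch.zeros(num_rows + 1, dtype=torch.long)
    rowptr[1:] = counts.cumsum(0)
    return rowptr, cols, vals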
@@ -1462,41 +1595,43 @@ def test_coo2csc():

    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

@@ -1507,12 +1642,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item() nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) cooA = F.COOSparseTensor(
A2 = A*idx A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half()) out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3*statsBt.half()/127 out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True) values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int() offset = counts.cumsum(0).int()
...@@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ...@@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
p = 200/(2048*12288*4) p = 200 / (2048 * 12288 * 4)
n = out1.numel() n = out1.numel()
count = math.ceil(p*n) count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
# torch.cuda.synchronize()
# t0 = time.time()
#torch.cuda.synchronize() # for i in range(100):
#t0 = time.time()
#for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B) # out2 = F.spmm_coo_very_sparse(cooA, B)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print('fp16', time.time() - t0) # print('fp16', time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
out2 = F.spmm_coo(cooA, B) out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize() torch.cuda.synchronize()
print('cusparse fp16', time.time() - t0) print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt) out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize() torch.cuda.synchronize()
print('int8', time.time() - t0) print("int8", time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize() torch.cuda.synchronize()
print('int8+dequant', time.time() - t0) print("int8+dequant", time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
out2 = torch.matmul(A, B) out2 = torch.matmul(A, B)
torch.cuda.synchronize() torch.cuda.synchronize()
print('matmul', time.time() - t0) print("matmul", time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
out1 = bnb.matmul(A, Bt) out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1+out2 out = out1 + out2
torch.cuda.synchronize() torch.cuda.synchronize()
print('sparse+ matmul', time.time() - t0) print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
...@@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ...@@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
out1 = bnb.matmul(A, Bt) out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize() torch.cuda.synchronize()
print('partial matmul', time.time() - t0) print("partial matmul", time.time() - t0)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
out1 = bnb.matmul(A, Bt) out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize() torch.cuda.synchronize()
print('partial matmul', time.time() - t0) print("partial matmul", time.time() - t0)
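# The timing loops above compare, in order: cuSPARSE fp16 spmm, the int8
# "very sparse" kernel with and without fused dequantization, a plain fp16
# matmul, the combined cost of a dense int8 matmul plus the sparse-outlier
# matmul (the two kernels a mixed dense/sparse decomposition would launch),
# and two "partial matmul" baselines (the int8 matmul with and without an
# extra fp16 matmul over the selected rowidx columns). Only wall-clock time
# is compared here; numerical agreement was already checked above with
# assert_all_approx_close.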
batch_size = 1 batch_size = 1
seqdim = 2048 seqdim = 2048
values = [] values = []
values.append((batch_size, seqdim, 768, 4*768)) values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024)) # values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536)) # values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048)) # values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560)) # values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096)) # values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5140, 4*5140)) # values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288)) # values.append((batch_size, seqdim, 12288, 4*12288))
names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden): def test_bench_matmul(batch, seq, model, hidden):
formatB = F.get_special_format_str() formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device='cuda').half() A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device='cuda') B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
...@@ -1613,31 +1751,37 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda() outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0 A[:, :, outliers] = 8.0
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() linearMixedBit = (
bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
)
linearMixedBit.eval() linearMixedBit.eval()
# warmup # warmup
for i in range(100): for i in range(100):
torch.matmul(A, B.t()) torch.matmul(A, B.t())
torch.cuda.synchronize() torch.cuda.synchronize()
print('') print("")
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
torch.matmul(A, B.t()) torch.matmul(A, B.t())
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
bnb.matmul(A, B) bnb.matmul(A, B)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
C32A, SA = F.transform(CA, 'col32') C32A, SA = F.transform(CA, "col32")
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
CxB, SB = F.transform(CB, to_order=formatB) CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -1645,7 +1789,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100): for i in range(100):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
BA, statsB = F.vectorwise_quant(B, dim=1) BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB) CxB, SB = F.nvidia_transform(CB, to_order=formatB)
...@@ -1654,26 +1800,30 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100): for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous() A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1) CA, statsA = F.vectorwise_quant(A2, dim=1)
C32A, SA = F.nvidia_transform(CA, 'col32') C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear') BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB) CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous() A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear') CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
C32A, SA = F.nvidia_transform(CA, 'col32') C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
out = Cout*statsB*statsA*(1.0/(127*127)) out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linear8bit(A) linear8bit(A)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -1681,8 +1831,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100): for i in range(100):
linear8bit(A) linear8bit(A)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linearMixedBit(A) linearMixedBit(A)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -1690,65 +1841,66 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100): for i in range(100):
linearMixedBit(A) linearMixedBit(A)
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') print(
f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
def test_zeropoint(): def test_zeropoint():
def min_max(x): def min_max(x):
maxA = torch.amax(x, dim=1, keepdim=True) maxA = torch.amax(x, dim=1, keepdim=True)
minA = torch.amin(x, dim=1, keepdim=True) minA = torch.amin(x, dim=1, keepdim=True)
midpoint = (maxA-minA)/2.0 midpoint = (maxA - minA) / 2.0
dyna = 252/(maxA-minA) dyna = 252 / (maxA - minA)
#dyna *= 0.98 # dyna *= 0.98
x = dyna*x x = dyna * x
x = x - torch.round((dyna*(minA+midpoint))) x = x - torch.round((dyna * (minA + midpoint)))
return x.to(torch.int8), minA, midpoint, dyna return x.to(torch.int8), minA, midpoint, dyna
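# min_max above rescales each row's [min, max] window symmetrically onto the
# int8 range: dyna = 252 / (max - min) stretches the window to about 252
# steps, and subtracting round(dyna * (min + midpoint)) recenters it on zero
# before the cast. Worked example with assumed values min = -1, max = 3:
# midpoint = 2, dyna = 63, so x = 3 maps to 63*3 - round(63*1) = 126 and
# x = -1 maps to -63 - 63 = -126.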
batch = 2 batch = 2
seq = 2 seq = 2
model = 4 model = 4
hidden = 2*model hidden = 2 * model
#batch = 4 # batch = 4
#seq = 2048 # seq = 2048
#model = 1024 # model = 1024
#hidden = 8*model # hidden = 8*model
A = torch.randn(batch*seq, model, device='cuda').half()-0.4 A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half()) B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())
#A[0] = 0 # A[0] = 0
#B[:, 0] = 0 # B[:, 0] = 0
#A = A*(A>0) # A = A*(A>0)
#A[0, 0] = 0 # A[0, 0] = 0
#A[0, 0] = 6.0 # A[0, 0] = 6.0
Ac, minA, midpoint, dyna = min_max(A) Ac, minA, midpoint, dyna = min_max(A)
#print(Ac[0, 0], 'zero') # print(Ac[0, 0], 'zero')
#print(Ac, Ac.min(), Ac.max()) # print(Ac, Ac.min(), Ac.max())
Bc, maxB = F.vectorwise_quant(B, quant_type='linear') Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
out = F.igemm(Ac, Bc) out = F.igemm(Ac, Bc)
out2 = torch.matmul(A,B) out2 = torch.matmul(A, B)
offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
out = out.float() out = out.float()
#print(out.shape, maxB.shape, scale.shape, offset.shape) # print(out.shape, maxB.shape, scale.shape, offset.shape)
norm1 = maxB/127 norm1 = maxB / 127
C4 = (out/dyna)*norm1+offset C4 = (out / dyna) * norm1 + offset
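# Sketch of why this C4 recovers the fp16 product: with Ac ~ dyna*A - z, where
# z = round(dyna*(minA + midpoint)), and Bc ~ 127*B/maxB, the igemm output is
# out ~ dyna*(127/maxB)*(A @ B) - z*(127/maxB)*B.sum(0). Dividing by dyna,
# multiplying by norm1 = maxB/127, and adding offset = B.sum(0)*z/dyna cancels
# both the scale factors and the zero-point term (up to rounding error).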
B1 = torch.nn.Parameter(B.clone()) B1 = torch.nn.Parameter(B.clone())
B2 = torch.nn.Parameter(B.clone()) B2 = torch.nn.Parameter(B.clone())
B3 = torch.nn.Parameter(B.clone()) B3 = torch.nn.Parameter(B.clone())
B4 = torch.nn.Parameter(B.clone()) B4 = torch.nn.Parameter(B.clone())
C1 = torch.matmul(A, B1) C1 = torch.matmul(A, B1)
C2 = bnb.matmul_cublas(A, B2, None, 'linear') C2 = bnb.matmul_cublas(A, B2, None, "linear")
C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint') C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint') C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
err1 = torch.abs(C1-C2).mean().item() err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1-C3).mean().item() err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1-C4).mean().item() err3 = torch.abs(C1 - C4).mean().item()
print(err1, err2, err3) print(err1, err2, err3)
#assert err1 > err2 # assert err1 > err2
loss1 = C1.mean() loss1 = C1.mean()
loss2 = C2.mean() loss2 = C2.mean()
...@@ -1765,40 +1917,38 @@ def test_zeropoint():
print(B2.grad) print(B2.grad)
print(B3.grad) print(B3.grad)
print(B4.grad) print(B4.grad)
err1 = torch.abs(B1.grad-B2.grad).mean().item() err1 = torch.abs(B1.grad - B2.grad).mean().item()
err2 = torch.abs(B1.grad-B3.grad).mean().item() err2 = torch.abs(B1.grad - B3.grad).mean().item()
err3 = torch.abs(B1.grad-B4.grad).mean().item() err3 = torch.abs(B1.grad - B4.grad).mean().item()
print(err1, err2, err3) print(err1, err2, err3)
def test_zp(): def test_zp():
def quant_zp(x): def quant_zp(x):
dtype = x.dtype dtype = x.dtype
x = x.float() x = x.float()
dyna = x.max() - x.min() dyna = x.max() - x.min()
if dyna == 0: dyna = 1 if dyna == 0:
qx = 254./dyna dyna = 1
qx = 254.0 / dyna
minx = x.min() minx = x.min()
#zpx = torch.round(minx* qx) # zpx = torch.round(minx* qx)
#zpx = 127 - torch.round(x.max()* qx) # zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min()* qx) - 127 zpx = torch.round(x.min() * qx) - 127
x = (qx*x) + zpx x = (qx * x) + zpx
return x, qx, zpx return x, qx, zpx
batch = 2 batch = 2
seq = 512 seq = 512
model = 1024 model = 1024
hidden = 4*model hidden = 4 * model
A = torch.randn(batch*seq, model, device='cuda').half()*0.1 A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
B = torch.randn(model, hidden, device='cuda').half()*0.1 B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B) C0 = torch.matmul(A, B)
# A, SA = F.vectorwise_quant(A, quant_type='linear')
#A, SA = F.vectorwise_quant(A, quant_type='linear') # B, SB = F.vectorwise_quant(B, quant_type='linear')
#B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float() A = A.float()
B = B.float() B = B.float()
...@@ -1806,69 +1956,68 @@ def test_zp():
C3 = bnb.matmul(A.half(), B.t().contiguous().half()) C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1 zp = 1
#C2 = torch.matmul(A-zp, B) # C2 = torch.matmul(A-zp, B)
#C2 += B.sum(0).view(1, -1)*zp # C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B-zp) C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1)*zp C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A) ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max()) print(ca.min(), ca.max())
print((ca-cza).min(), (ca-cza).max()) print((ca - cza).min(), (ca - cza).max())
zp = 1 zp = 1
scale = 2.0 scale = 2.0
C5 = torch.matmul((A*scale)-zp, B) C5 = torch.matmul((A * scale) - zp, B)
C5 += B.sum(0)*zp C5 += B.sum(0) * zp
C5 /= scale C5 /= scale
CA, qa, zpa = quant_zp(A) CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B) C4 = torch.matmul(CA, B)
C4 -= B.sum(0)*zpa C4 -= B.sum(0) * zpa
C4 /= qa C4 /= qa
zpb = 1 zpb = 1
zpa = 1 zpa = 1
qa = 2 qa = 2
qb = 2 qb = 2
C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb) C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C6 -= zpa*zpb*A.shape[1] C6 -= zpa * zpb * A.shape[1]
C6 /= qa*qb C6 /= qa * qb
CA, qa, zpa = quant_zp(A) CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B) CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB) C7 = torch.matmul(CA, CB)
C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C7 -= zpa*zpb*A.shape[1] C7 -= zpa * zpb * A.shape[1]
C7 /= qa*qb C7 /= qa * qb
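# The corrections for C6 and C7 follow from expanding the quantized product,
# with zpa/zpb scalars and K = A.shape[1] the inner dimension:
#   (qa*A + zpa) @ (qb*B + zpb)
#       = qa*qb*(A @ B) + qa*zpb*A.sum(1) + qb*zpa*B.sum(0) + zpa*zpb*K
# Subtracting the two rank-1 terms and the constant, then dividing by qa*qb,
# recovers A @ B exactly for C6 and up to rounding error for the quantized C7.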
print('') print("")
#print(C0.flatten()[:10]) # print(C0.flatten()[:10])
print(C1.flatten()[:10]) print(C1.flatten()[:10])
print(C2.flatten()[:10]) print(C2.flatten()[:10])
print(C3.flatten()[:10]) print(C3.flatten()[:10])
print(C5.flatten()[:10]) print(C5.flatten()[:10])
print(C6.flatten()[:10]) print(C6.flatten()[:10])
print(C7.flatten()[:10]) print(C7.flatten()[:10])
err1 = torch.abs(C1-C2).mean().item() err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1-C3).mean().item() err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1-C4).mean().item() err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1-C5).mean().item() err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1-C6).mean().item() err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1-C7).mean().item() err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6) print(err1, err2, err3, err4, err5, err6)
def test_extract_outliers(): def test_extract_outliers():
for i in range(k): for i in range(k):
shapeA = (4096, 4096*4) shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
#idx = torch.Tensor([0]).int().cuda() # idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()] outliers1 = A[:, idx.long()]
CA, SA = F.transform(A, 'col_turing') CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx) outliers2 = F.extract_outliers(CA, SA, idx)
...@@ -1877,7 +2026,7 @@ def test_extract_outliers():
torch.testing.assert_allclose(outliers1, outliers2) torch.testing.assert_allclose(outliers1, outliers2)
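# extract_outliers is expected to read the columns listed in idx directly out
# of the tiled col_turing layout and match plain row-major column indexing
# (outliers1 = A[:, idx]); the same extraction is repeated for the col_ampere
# layout below.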
CA, SA = F.transform(A, 'col_ampere') CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx) outliers2 = F.extract_outliers(CA, SA, idx)
......
from itertools import product
import pytest import pytest
import torch import torch
from itertools import product
from torch import nn from torch import nn
import bitsandbytes as bnb import bitsandbytes as bnb
class MockArgs(object): class MockArgs(object):
def __init__(self, initial_data): def __init__(self, initial_data):
for key in initial_data: for key in initial_data:
setattr(self, key, initial_data[key]) setattr(self, key, initial_data[key])
class MLP8bit(torch.nn.Module): class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
super(MLP8bit, self).__init__() super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold) self.fc1 = bnb.nn.Linear8bitLt(
self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold) dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
)
def forward(self, x): def forward(self, x):
x = self.fc1(x) x = self.fc1(x)
...@@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
def get_args(): def get_args():
args = MockArgs([]) args = MockArgs([])
args.quant_type = 'vector' args.quant_type = "vector"
args.use_8bit_training = 'full' args.use_8bit_training = "full"
args.clip_freq = 9999 args.clip_freq = 9999
return args return args
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
idx = torch.isclose(a, b, rtol, atol) idx = torch.isclose(a, b, rtol, atol)
sumval = (idx==0).sum().item() sumval = (idx == 0).sum().item()
if sumval > count: if sumval > count:
print(f'Too many values not close: assert {sumval} < {count}') print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol) torch.testing.assert_allclose(a, b, rtol, atol)
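# Behavior sketch: up to `count` elements may fall outside the rtol/atol band
# before the strict torch.testing.assert_allclose is triggered. For example,
# with assumed toy tensors:
#     a = torch.zeros(100); b = a.clone(); b[0] = 1.0
#     assert_all_approx_close(a, b, atol=1e-3, rtol=0, count=2)  # passes
# whereas count=0 would fall through to assert_allclose and fail.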
class LinearFunction(torch.autograd.Function):
class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = LinearFunction.round_stoachastic if stochastic else torch.round
norm = math.sqrt(math.pi)/math.sqrt(2.0) norm = math.sqrt(math.pi) / math.sqrt(2.0)
#std = torch.abs(x).mean()*norm # std = torch.abs(x).mean()*norm
std = torch.std(x) std = torch.std(x)
max1 = std*trim_value max1 = std * trim_value
x = x/max1*127 x = x / max1 * 127
x = round_func(x) x = round_func(x)
x[x > 127] = 127 x[x > 127] = 127
x[x < -127] = -127 x[x < -127] = -127
x = x/127*max1 x = x / 127 * max1
return x return x
def quant(x, quant_type, dim=1): def quant(x, quant_type, dim=1):
if quant_type == 'linear': if quant_type == "linear":
max1 = torch.abs(x).max().float() max1 = torch.abs(x).max().float()
xq = torch.round(x/max1*127).to(torch.int8) xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1 return xq, max1
elif quant_type == 'vector': elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x/max1*127).to(torch.int8) xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1 return xq, max1
elif quant_type == 'min-max': elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float() maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float() minA = torch.amin(x, dim=dim, keepdim=True).float()
scale = (maxA-minA)/2.0 scale = (maxA - minA) / 2.0
xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8) xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float()) return xq, (minA.float(), scale.float())
else: return None else:
return None
def dequant(xq, S1, S2, dtype, quant_type): def dequant(xq, S1, S2, dtype, quant_type):
if quant_type == 'linear': if quant_type == "linear":
norm = S1*S2/(127*127) norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows # double cast needed to prevent overflows
return (xq.float()*norm).to(dtype) return (xq.float() * norm).to(dtype)
elif quant_type == 'vector': elif quant_type == "vector":
x = xq.float() x = xq.float()
if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0) if len(xq.shape) == 2 and len(S1.shape) == 3:
if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0) S1 = S1.squeeze(0)
#print(x.shape, S1.shape, S2.shape) if len(xq.shape) == 2 and len(S2.shape) == 3:
S2 = S2.squeeze(0)
# print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2: if len(S1.shape) == 2:
x *= S1.t()/127 x *= S1.t() / 127
else: else:
x *= S1/127 x *= S1 / 127
x *= S2/127 x *= S2 / 127
return x.to(dtype) return x.to(dtype)
else: return None else:
return None
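# Why dequant divides by 127 twice in the vector case: quant stores roughly
# 127*x/S per operand, so the int32 igemm of two such tensors carries a factor
# of (127/S1)*(127/S2); multiplying the product by S1/127 and S2/127 (with S1
# transposed when needed to line up with the output columns) restores the
# original floating-point scale.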
def dequant_min_max(xq, A, B, SA, SB, dtype): def dequant_min_max(xq, A, B, SA, SB, dtype):
offset = B.float().t().sum(0)*(SA[0]+SA[1]) offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float() x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) if len(xq.shape) == 2 and len(SB.shape) == 3:
if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0) SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SA.shape) == 3:
SA = SA.squeeze(0)
if len(SB.shape) == 2: if len(SB.shape) == 2:
x *= SB.t()/127 x *= SB.t() / 127
else: else:
x *= SB/127 x *= SB / 127
x *= SA[1]/127 x *= SA[1] / 127
x +=offset x += offset
return x.to(dtype) return x.to(dtype)
def get_8bit_linear(x, stochastic=False): def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max() max1 = torch.abs(x).max()
x = x/max1*127 x = x / max1 * 127
x = round_func(x)/127*max1 x = round_func(x) / 127 * max1
#x = torch.round(x)/128*max1 # x = torch.round(x)/128*max1
return x return x
@staticmethod @staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False): def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1==0] = 1.0 max1[max1 == 0] = 1.0
x = (x*127)/max1 x = (x * 127) / max1
x = round_func(x)/127*max1 x = round_func(x) / 127 * max1
return x return x
@staticmethod @staticmethod
def round_stoachastic(x): def round_stoachastic(x):
sign = torch.sign(x) sign = torch.sign(x)
absx = torch.abs(x) absx = torch.abs(x)
decimal = absx-torch.floor(absx) decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal) rdm = torch.rand_like(decimal)
return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype)) return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
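# The stochastic rounding above floors the magnitude and then adds 1 with
# probability equal to the fractional part, so it is unbiased in expectation:
# for example 2.3 rounds to 3 about 30% of the time and to 2 otherwise, giving
# an expected value of 2.3.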
@staticmethod @staticmethod
def fake_8bit_storage(w, exponent_bits): def fake_8bit_storage(w, exponent_bits):
...@@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def fake_8bit_storage_quantile(w, args): def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
#C = bnb.functional.quantize_no_absmax(code, w) # C = bnb.functional.quantize_no_absmax(code, w)
#out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
#print(out) # print(out)
#out = out.half() # out = out.half()
code /= torch.max(torch.abs(code)) code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code) out = bnb.functional.dequantize_blockwise(absmax, C, code)
...@@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def fake_8bit_storage_with_max(w, topk=8): def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256) blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk] idx = idx[:, :topk]
max_val = max_val[:, :topk] max_val = max_val[:, :topk]
...@@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
w.copy_(unblocked_w) w.copy_(unblocked_w)
return unblocked_w return unblocked_w
@staticmethod @staticmethod
def forward(ctx, x, weight, bias=None, args=None): def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != 'off': if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t()) outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
#if torch.rand(1) < 0.01: # if torch.rand(1) < 0.01:
#output32 = torch.matmul(x, weight.t()) # output32 = torch.matmul(x, weight.t())
#err = torch.abs(output-output32).float() # err = torch.abs(output-output32).float()
#relerr = err/(torch.abs(output32).float()+1e-8) # relerr = err/(torch.abs(output32).float()+1e-8)
#print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else: else:
#output = torch.matmul(x, weight.t()) # output = torch.matmul(x, weight.t())
output = torch.einsum('bsi,oi->bso', x, weight) output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias) ctx.save_for_backward(x, weight, bias)
ctx.args = args ctx.args = args
...@@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
args = ctx.args args = ctx.args
stochastic = False stochastic = False
grad_input = grad_weight = grad_bias = None grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# weight and x are already 8bit # weight and x are already 8bit
# -> transform grad_output to 8-bit # -> transform grad_output to 8-bit
if args.use_8bit_training == 'forward+wgrad': if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8) grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
#grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == 'full': elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8) bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8) grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) grad_input = LinearFunction.dequant(
grad_input8, S1, S3, grad_output.dtype, args.quant_type
)
else: else:
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum('bsi,bso->oi', x, grad_output) grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None return grad_input, grad_weight, grad_bias, None
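# Backward summary for the "full" 8-bit path: grad_weight comes from an int8
# igemm of grad_output and x quantized over the (batch, seq) dimensions, and
# grad_input from a second igemm of grad_output and the weight quantized along
# their feature dimensions; both results are dequantized back to the incoming
# dtype with the same scale bookkeeping as the forward pass.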
class Linear8bit(nn.Module): class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None): def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__() super(Linear8bit, self).__init__()
...@@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
if bias: if bias:
self.bias = nn.Parameter(torch.empty(output_features)) self.bias = nn.Parameter(torch.empty(output_features))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight) torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None: if self.bias is not None:
...@@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args) return LinearFunction.apply(x, self.weight, self.bias, self.args)
def test_linear8bit(): def test_linear8bit():
l0 = torch.nn.Linear(32, 64).cuda().half() l0 = torch.nn.Linear(32, 64).cuda().half()
l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half() l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
l2 = Linear8bit(32, 64, args=get_args()).cuda().half() l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
l3 = bnb.nn.Linear8bitLt(32,64).cuda().half() l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
l0.weight.data = l2.weight.data.clone() l0.weight.data = l2.weight.data.clone()
l0.bias.data = l2.bias.data.clone() l0.bias.data = l2.bias.data.clone()
...@@ -292,8 +314,8 @@ def test_linear8bit():
l3.bias.data = l2.bias.data.clone() l3.bias.data = l2.bias.data.clone()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
t = torch.randn(16, 8, 64, device='cuda').half() t = torch.randn(16, 8, 64, device="cuda").half()
b2 = b1.clone() b2 = b1.clone()
b3 = b1.clone() b3 = b1.clone()
b0 = b1.clone() b0 = b1.clone()
...@@ -318,16 +340,20 @@ def test_linear8bit():
assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) assert_all_approx_close(
assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
assert_all_approx_close(
l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item() err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item() err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item() err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
assert err1*0.8 < err2 assert err1 * 0.8 < err2
assert err2*0.8 < err3 assert err2 * 0.8 < err3
assert err3*0.8 < err1 assert err3 * 0.8 < err1
l0.weight.grad = None l0.weight.grad = None
l1.weight.grad = None l1.weight.grad = None
...@@ -341,23 +367,28 @@ def test_linear8bit():
threshold = [0.0, 3.0] threshold = [0.0, 3.0]
values = threshold values = threshold
names = ['threshold_{0}'.format(vals) for vals in values] names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold): def test_linear8bitlt_inference(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half() l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
assert l1.weight.device.type == 'cuda' assert l1.weight.device.type == "cuda"
assert l1.weight.dtype == torch.float16 assert l1.weight.dtype == torch.float16
l1.eval() l1.eval()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
if i == 1: if i == 1:
assert l1.state.CxB is not None assert l1.state.CxB is not None
def test_linear8bitlt_accumulated_gradient(): def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)]) l1 = torch.nn.Sequential(
l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)]) *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
)
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
...@@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10 acc_steps = 10
for i in range(10): for i in range(10):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
o2 = l2(b1) o2 = l2(b1)
loss1 = o1.mean() loss1 = o1.mean()
...@@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True) opt1.zero_grad(True)
opt2.step() opt2.step()
opt2.zero_grad(True) opt2.zero_grad(True)
assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) assert_all_approx_close(
assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
)
assert_all_approx_close(
l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
)
# we do this copy because otherwise we have small divergences over time that add up # we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data) l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data) l1[1].weight.data.copy_(l2[1].weight.data)
...@@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0] threshold = [0.0, 2.0]
values = threshold values = threshold
names = ['threshold_{0}'.format(vals) for vals in values] names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold): def test_linear8bitlt_no_fp16_weights(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half() l1 = (
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
assert l1.weight.dtype == torch.int8 assert l1.weight.dtype == torch.int8
l1.eval() l1.eval()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
...@@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda' assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == 'cuda' assert mlp.fc2.weight.device.type == "cuda"
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda') mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.to(torch.float16)
.to("cuda")
)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda' assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == 'cuda' assert mlp.fc2.weight.device.type == "cuda"
import ctypes
import os import os
import time
import shutil import shutil
import time
import uuid import uuid
from itertools import product
from os.path import join
import pytest import pytest
import ctypes
import torch import torch
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional as F import bitsandbytes.functional as F
from os.path import join # import apex
from itertools import product
#import apex
k = 20 k = 20
def get_temp_dir(): def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
return path return path
def rm_path(path): def rm_path(path):
shutil.rmtree(path) shutil.rmtree(path)
str2optimizers = {} str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers["momentum_pytorch"] = (
#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) None,
#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
bnb.optim.Adam,
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) )
#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) # str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers["momentum"] = (
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) )
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) str2optimizers["lars"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) )
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) # str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers["lars8bit"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers["adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2statenames = {} str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames['momentum'] = [('momentum_buffer', 'state1')] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames['lars'] = [('momentum_buffer', 'state1')] str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames['rmsprop'] = [('square_avg', 'state1')] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames["adam8bit"] = [
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] ("exp_avg", "state1", "qmap1", "max1"),
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] ("exp_avg_sq", "state2", "qmap2", "max2"),
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] ]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] str2statenames["lamb8bit"] = [
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] ("exp_avg", "state1", "qmap1", "max1"),
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] ("exp_avg_sq", "state2", "qmap2", "max2"),
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] ]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1,dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone() p2 = p1.clone()
p1 = p1.float() p1 = p1.float()
torch_optimizer = str2optimizers[optim_name][0]([p1]) torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
...@@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
else: else:
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
for i in range(k): for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float() p1.grad = g.clone().float()
p2.grad = g.clone() p2.grad = g.clone()
...@@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step() torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]: for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
if i % (k//5) == 0 and i > 0: if i % (k // 5) == 0 and i > 0:
path = get_temp_dir() path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer del bnb_optimizer
bnb_optimizer = None bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path) rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]: for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
if gtype == torch.float16: if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit # the adam buffers should also be close because they are 32-bit
...@@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.half().float() p1.data = p1.data.half().float()
p2.copy_(p1.data) p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2) torch.testing.assert_allclose(p1.half(), p2)
if optim_name in ['lars', 'lamb']: if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0 assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
values = list(product(dim1,dim2, gtype)) values = list(product(dim1, dim2, gtype))
names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values] names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype): def test_global_config(dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 return
p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
mask = torch.rand_like(p2) < 0.1 mask = torch.rand_like(p2) < 0.1
beta1 = 0.9 beta1 = 0.9
beta2 = 0.999 beta2 = 0.999
...@@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8 eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize() bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
p1 = p1.cuda() p1 = p1.cuda()
...@@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype): ...@@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype):
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
for i in range(50): for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
p1.grad = g1 p1.grad = g1
p2.grad = g2 p2.grad = g2
p3.grad = g3 p3.grad = g3
adam2.step() adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]["state1"].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8 assert adam2.state[p3]["state2"].dtype == torch.uint8
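test_global_config exercises the per-parameter override path of GlobalOptimManager: p3 is overridden to 8-bit optimizer state before the parameters are registered and moved to CUDA, which is why its state1/state2 buffers come out as torch.uint8. A hedged sketch of that pattern outside the test harness; the optimizer construction, tensor sizes, and learning rate are illustrative assumptions, while the override/register calls mirror the ones above:

import torch
import bitsandbytes as bnb

p1 = torch.randn(1024, 1024, device="cpu") * 0.1
p2 = torch.randn(1024, 1024, device="cpu") * 0.1
p3 = torch.randn(1024, 1024, device="cpu") * 0.1

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.initialize()
mng.override_config(p3, "optim_bits", 8)  # only p3 gets 8-bit optimizer state
mng.register_parameters([p1, p2, p3])     # register while the tensors are still on the CPU

p1, p2, p3 = p1.cuda(), p2.cuda(), p3.cuda()
adam = bnb.optim.Adam([p1, p2, p3], lr=0.001)  # pass params in the registered order

for p in (p1, p2, p3):
    p.grad = torch.randn_like(p) * 0.1 + 0.001
adam.step()

assert adam.state[p3]["state1"].dtype == torch.uint8    # overridden parameter
assert adam.state[p1]["state1"].dtype == torch.float32  # default 32-bit state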
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] optimizer_names = [
values = list(product(dim1,dim2, gtype, optimizer_names)) "adam8bit",
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] "momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lamb8bit",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name): def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone() p2 = p1.clone()
p1 = p1.float() p1 = p1.float()
blocksize = 2048 blocksize = 2048
...@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors = [] relerrors = []
for i in range(50): for i in range(50):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float() p1.grad = g.clone().float()
p2.grad = g.clone() p2.grad = g.clone()
...@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states = [] dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]: for name1, name2, qmap, max_val in str2statenames[optim_name]:
#print(bnb_optimizer.state[p2][max_val], name1) # print(bnb_optimizer.state[p2][max_val], name1)
if 'blockwise' in optim_name: if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else: else:
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) s1 = F.dequantize(
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20 assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone()) dequant_states.append(s1.clone())
err = torch.abs(p1-p2) err = torch.abs(p1 - p2)
relerr = err/torch.abs(p1) relerr = err / torch.abs(p1)
assert err.mean() < 0.0001 assert err.mean() < 0.0001
assert relerr.mean() < 0.001 assert relerr.mean() < 0.001
...@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item()) relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0: if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
s1cpy = s.clone() s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone() raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir() path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer del bnb_optimizer
bnb_optimizer = None bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path) rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
if 'blockwise' in optim_name: if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else: else:
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
torch.testing.assert_allclose(s1cpy, s1) torch.testing.assert_allclose(s1cpy, s1)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20 assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
...@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.to(gtype).float() p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data) p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2) torch.testing.assert_allclose(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
torch_optimizer.state[p1][name1].copy_(s.data) torch_optimizer.state[p1][name1].copy_(s.data)
#print(sum(errors)/len(errors)) # print(sum(errors)/len(errors))
#print(sum(relerrors)/len(relerrors)) # print(sum(relerrors)/len(relerrors))
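The comparison logic in test_optimizer8bit works by dequantizing the 8-bit optimizer buffers back to float32 and checking them element-wise against the corresponding full-precision buffers of the reference torch optimizer. A hedged sketch of that dequantization step on its own, assuming a blockwise Adam8bit whose state keys are 'state1'/'qmap1'/'absmax1' and 'state2'/'qmap2'/'absmax2' (the buffers the test refers to via str2statenames); shape and learning rate are illustrative:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

p = torch.nn.Parameter(torch.randn(1024, 1024, device="cuda") * 0.1)
opt = bnb.optim.Adam8bit([p], lr=1e-3)  # blockwise 8-bit state by default

p.grad = torch.randn_like(p) * 0.01
opt.step()  # quantized state plus qmap/absmax buffers now exist

state = opt.state[p]
# Recover float32 views of the first and second Adam moments.
exp_avg = F.dequantize_blockwise(
    code=state["qmap1"], absmax=state["absmax1"], A=state["state1"], blocksize=2048
)
exp_avg_sq = F.dequantize_blockwise(
    code=state["qmap2"], absmax=state["absmax2"], A=state["state2"], blocksize=2048
)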
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32] gtype = [torch.float32]
optim_bits = [32, 8] optim_bits = [32, 8]
values = list(product(dim1,dim2, gtype, optim_bits)) values = list(product(dim1, dim2, gtype, optim_bits))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values] names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
beta1 = 0.9 beta1 = 0.9
beta2 = 0.999 beta2 = 0.999
lr = 0.001 lr = 0.001
...@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): ...@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p1 = p1.cuda() p1 = p1.cuda()
p2 = p1.clone() p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) adam2 = bnb.optim.Adam(
[p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
)
gnorm_vec = torch.zeros(100).cuda() gnorm_vec = torch.zeros(100).cuda()
step = 0 step = 0
for i in range(50): for i in range(50):
step += 1 step += 1
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i) g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
g2 = g1.clone() g2 = g1.clone()
p2.grad = g2 p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
g1 = (g1.float()*gnorm_scale).to(gtype) g1, gnorm_vec, step, 5
)
g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1 p1.grad = g1
adam1.step() adam1.step()
...@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): ...@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32: if optim_bits == 32:
torch.testing.assert_allclose(p1, p2) torch.testing.assert_allclose(p1, p2)
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4) torch.testing.assert_allclose(
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4) adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=5e-5,
rtol=1e-4,
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=5e-5,
rtol=1e-4,
)
elif optim_bits == 8: elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3) torch.testing.assert_allclose(
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3) adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1']) )
adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2']) torch.testing.assert_allclose(
adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
)
adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
if i % 10 == 0 and i > 0: if i % 10 == 0 and i > 0:
path = get_temp_dir() path = get_temp_dir()
torch.save(adam2.state_dict(),join(path, 'opt.pt')) torch.save(adam2.state_dict(), join(path, "opt.pt"))
del adam2 del adam2
adam2 = None adam2 = None
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) adam2 = bnb.optim.Adam(
adam2.load_state_dict(torch.load(join(path, 'opt.pt'))) [p2],
lr,
(beta1, beta2),
eps,
optim_bits=optim_bits,
percentile_clipping=5,
)
adam2.load_state_dict(torch.load(join(path, "opt.pt")))
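test_adam_percentile_clipping checks that clipping gradients manually with F.percentile_clipping matches the optimizer's built-in percentile_clipping=5 option, up to the non-determinism from warp reductions noted in the comment below. A hedged, standalone version of the manual path used for p1 above; sizes, dtype, and iteration count are illustrative:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

p = torch.nn.Parameter(torch.randn(1024, 1024, device="cuda") * 0.1)
adam = bnb.optim.Adam([p], lr=0.001)

gnorm_vec = torch.zeros(100, device="cuda")  # rolling history of gradient norms
step = 0
for i in range(50):
    step += 1
    g = torch.randn_like(p) * 0.1 + 0.01 * i
    # gnorm_scale < 1 when the current gradient norm exceeds the percentile
    # threshold derived from the recent history, otherwise 1.0
    current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g, gnorm_vec, step, 5)
    p.grad = (g.float() * gnorm_scale).to(g.dtype)  # rescale before the update
    adam.step()

The same behaviour is available inside the optimizer via the percentile_clipping=5 argument, which is how adam2 is constructed in the test.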
dim1 = [4096] dim1 = [4096]
dim2 = [4096] dim2 = [4096]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] # optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] # optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
#optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lamb_apex', 'lamb8bit']
#optimizer_names = ['lars_apex', 'lars8bit'] # optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ['adam8bit_blockwise'] optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1,dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1]) bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g p1.grad = g
for i in range(k): for i in range(k):
if i == k//5: if i == k // 5:
# 100 iterations for burn-in # 100 iterations for burn-in
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
...@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): ...@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
bnb_optimizer.step() bnb_optimizer.step()
torch.cuda.synchronize() torch.cuda.synchronize()
s = time.time()-t0 s = time.time() - t0
print('') print("")
params = (k-k//5)*dim1*dim2 params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s/params) print(optim_name, gtype, s / params)
#assert s < 3.9 # assert s < 3.9
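test_benchmark_blockwise uses a common CUDA timing pattern: treat the first k // 5 iterations as burn-in, call torch.cuda.synchronize() before starting and after stopping the clock so queued kernels are actually counted, and report time normalized by the number of updated elements. A small generic sketch of that pattern; the helper name and iteration counts are illustrative:

import time

import torch

def time_cuda_step(step_fn, iters=500, burn_in=100):
    """Return seconds per post-burn-in call; assumes step_fn launches CUDA work."""
    for i in range(iters):
        if i == burn_in:
            torch.cuda.synchronize()  # finish warm-up work before starting the clock
            t0 = time.time()
        step_fn()
    torch.cuda.synchronize()  # wait for all queued kernels before reading the clock
    return (time.time() - t0) / (iters - burn_in)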