Unverified commit 06029dd6, authored by Titus and committed by GitHub

Merge pull request #1081 from akx/ruff-format

Reformat Python code with Ruff
parents fd723b78 5a4263f4
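
The changes below apply Ruff's formatter and auto-fixable lint rules across the repository. As a rough illustration only (an editorial sketch, not part of this commit), the same reformat could be reproduced with a small driver like the following, assuming the `ruff` executable is installed and the script is run from the repository root:

import subprocess

def run_ruff(repo_root: str = ".") -> None:
    # Apply auto-fixable lint rules first, then run the code formatter.
    subprocess.run(["ruff", "check", "--fix", repo_root], check=True)
    subprocess.run(["ruff", "format", repo_root], check=True)

if __name__ == "__main__":
    run_ruff()
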
@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
        return None
else:
    import triton
    import triton.language as tl
@@ -15,21 +16,21 @@ else:
    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _dequantize_rowwise(
@@ -51,7 +52,6 @@ else:
        output = max_val * x * inv_127
        tl.store(output_ptr + offsets, output, mask=row_mask)

    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
        output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
@@ -60,5 +60,5 @@ else:
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (x.shape[0],)
        _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
        return output
@@ -3,14 +3,14 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
        return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # This is a matmul kernel based on triton.ops.matmul
    # It is modified to support rowwise quantized input and global quantized weight
    # It's purpose is fused matmul then dequantize
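
For orientation (an editorial sketch, not code from this commit): the fused kernel below multiplies int8 operands and immediately rescales the accumulator back to fp16. Assuming state_x holds the per-row absmax of the activation and state_w a single global absmax of the weight, a plain PyTorch reference of the same computation would look roughly like this:

import torch

def int8_matmul_mixed_dequantize_ref(a_int8, b_int8, state_x, state_w, bias=None):
    # a_int8: [M, K] rowwise-quantized activation; b_int8: [K, N] globally quantized weight
    divfactor = 1.0 / (127.0 * 127.0)                     # undo both 127 quantization factors
    acc = (a_int8.float() @ b_int8.float()) * divfactor   # accumulate in fp32, like ACC_TYPE = tl.float32
    acc = acc * state_x.float().view(-1, 1) * state_w.float()  # per-row scale, then global weight scale
    if bias is not None:
        acc = acc + bias.float()
    return acc.to(torch.float16)
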
@@ -27,58 +27,83 @@ else:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
                            num_stages=num_stages,
                            num_warps=num_warps,
                        ),
                    )
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
                                num_stages=num_stages,
                                num_warps=num_warps,
                                pre_hook=init_to_zero("C"),
                            ),
                        )
        return configs

    @triton.autotune(
        configs=[
            # basic configs for compute-bound matmuls
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            # good for int8
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            *get_configs_io_bound(),
        ],
        key=["M", "N", "K"],
        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
    )
    @triton.heuristics(
        {
            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
        },
    )
    @triton.jit
    def _int8_matmul_mixed_dequantize(
        A,
        B,
        C,
        bias,
        state_x_ptr,
        state_w_ptr,
        M,
        N,
        K,
        divfactor: tl.constexpr,
        has_bias: tl.constexpr,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        GROUP_M: tl.constexpr,
        SPLIT_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        ACC_TYPE: tl.constexpr,
    ):
        # matrix multiplication
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
@@ -115,13 +140,13 @@ else:
                b = tl.load(B)
            else:
                k_remaining = K - k * (BLOCK_K * SPLIT_K)
                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
            acc += tl.dot(a, b)
            A += BLOCK_K * SPLIT_K * stride_ak
            B += BLOCK_K * SPLIT_K * stride_bk

        acc = w_factor * (x_factor * (acc * divfactor))
        acc = acc.to(C.dtype.element_ty)

        # conditionally add bias
@@ -137,10 +162,9 @@ else:
        else:
            tl.atomic_add(C, acc, mask=mask)

    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
        device = a.device
        divfactor = 1.0 / (127.0 * 127.0)
        has_bias = 0 if bias is None else 1

        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
@@ -154,12 +178,28 @@ else:
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_mixed_dequantize kernel
        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
        _int8_matmul_mixed_dequantize[grid](
            a,
            b,
            c,
            bias,
            state_x,
            state_w,
            M,
            N,
            K,
            divfactor,
            has_bias,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
            ACC_TYPE=ACC_TYPE,
        )
        return c
@@ -3,7 +3,9 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
        return None
else:
    import triton
    import triton.language as tl
@@ -17,7 +19,6 @@ else:
    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    def get_configs_io_bound():
        configs = []
        for num_stages in [2, 3, 4, 5, 6]:
@@ -26,58 +27,83 @@ else:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
                            num_stages=num_stages,
                            num_warps=num_warps,
                        ),
                    )
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
                                num_stages=num_stages,
                                num_warps=num_warps,
                                pre_hook=init_to_zero("C"),
                            ),
                        )
        return configs

    @triton.autotune(
        configs=[
            # basic configs for compute-bound matmuls
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            # good for int8
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
            *get_configs_io_bound(),
        ],
        key=["M", "N", "K"],
        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
    )
    @triton.heuristics(
        {
            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
        },
    )
    @triton.jit
    def _int8_matmul_rowwise_dequantize(
        A,
        B,
        C,
        bias,
        state_x_ptr,
        state_w_ptr,
        M,
        N,
        K,
        divfactor,
        has_bias: tl.constexpr,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        GROUP_M: tl.constexpr,
        SPLIT_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        ACC_TYPE: tl.constexpr,
    ):
        # matrix multiplication
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
@@ -114,13 +140,13 @@ else:
                b = tl.load(B)
            else:
                k_remaining = K - k * (BLOCK_K * SPLIT_K)
                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
            acc += tl.dot(a, b)
            A += BLOCK_K * SPLIT_K * stride_ak
            B += BLOCK_K * SPLIT_K * stride_bk

        acc = w_factor * (x_factor * (acc * divfactor))
        acc = acc.to(C.dtype.element_ty)

        if has_bias:
@@ -135,9 +161,8 @@ else:
        else:
            tl.atomic_add(C, acc, mask=mask)

    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
        divfactor = 1.0 / (127.0 * 127.0)
        has_bias = 0 if bias is None else 1
@@ -154,12 +179,28 @@ else:
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_rowwise_dequantize kernel
        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
        _int8_matmul_rowwise_dequantize[grid](
            a,
            b,
            c,
            bias,
            state_x,
            state_w,
            M,
            N,
            K,
            divfactor,
            has_bias,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
            ACC_TYPE=ACC_TYPE,
        )
        return c
@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_columnwise_and_transpose(x: torch.Tensor):
        return None
else:
    import triton
    import triton.language as tl
@@ -15,23 +16,23 @@ else:
    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_stages=16),
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=16, num_warps=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _quantize_columnwise_and_transpose(
@@ -39,7 +40,8 @@ else:
        output_ptr,
        output_maxs,
        n_elements,
        M: tl.constexpr,
        N: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):
@@ -47,12 +49,12 @@ else:
        block_start = pid
        p2_arange = tl.arange(0, P2)
        p2_arange_mask = p2_arange < M
        arange = p2_arange * N
        offsets = block_start + arange
        x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
        output = tl.libdevice.llrint(127.0 * (x / max_val))
        new_start = pid * M
        new_offsets = new_start + p2_arange
@@ -68,6 +70,6 @@ else:
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
        return output, output_maxs

import torch

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_global_transpose(input):
        return None

    def quantize_global(x: torch.Tensor):
        return None
else:
    import triton
    import triton.language as tl

    # global quantize
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _quantize_global(
@@ -34,35 +35,43 @@ else:
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        absmax_inv = tl.load(absmax_inv_ptr)
        output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
        tl.store(output_ptr + offsets, output, mask=mask)

    def quantize_global(x: torch.Tensor):
        absmax = x.abs().max().unsqueeze(0)
        absmax_inv = 1.0 / absmax
        output = torch.empty(*x.shape, device="cuda", dtype=torch.int8)
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _quantize_global[grid](x, absmax_inv, output, n_elements)
        return output, absmax

    # global quantize and transpose
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
            # ...
        ],
        key=["M", "N"],
    )
    @triton.jit
    def _quantize_global_transpose(
        A,
        absmax_inv_ptr,
        B,
        stride_am,
        stride_an,
        stride_bn,
        stride_bm,
        M,
        N,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        GROUP_M: tl.constexpr,
    ):
        pid = tl.program_id(0)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N
@@ -86,20 +95,30 @@ else:
        B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
        tl.store(B, output, mask=mask)

    def quantize_global_transpose(input):
        absmax = input.abs().max().unsqueeze(0)
        absmax_inv = 1.0 / absmax
        M, N = input.shape
        out = torch.empty(N, M, device="cuda", dtype=torch.int8)
        assert out.size(0) == N and out.size(1) == M
        assert input.stride(0) == 1 or input.stride(1) == 1
        assert out.stride(0) == 1 or out.stride(1) == 1
        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
        _quantize_global_transpose[grid](
            input,
            absmax_inv,
            out,
            input.stride(0),
            input.stride(1),
            out.stride(0),
            out.stride(1),
            M,
            N,
        )
        return out, absmax
@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_rowwise(x: torch.Tensor):
        return None
else:
    import triton
    import triton.language as tl
@@ -15,21 +16,21 @@ else:
    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _quantize_rowwise(
@@ -49,7 +50,7 @@ else:
        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
        output = tl.libdevice.llrint(127.0 * (x / max_val))
        tl.store(output_ptr + offsets, output, mask=row_mask)
        tl.store(output_maxs + pid, max_val)
@@ -30,7 +30,7 @@ def outlier_hook(module, input):
            # (1) zscore test of std of hidden dimension
            outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
            # (2) magnitude > 6 test
            dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1)))
            outlier_idx2 = torch.where(dims > 0)[0]
            outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
            tracer.hvalue2outlier_idx[hvalue] = outlier_idx
@@ -59,14 +59,14 @@ class OutlierTracer:
                self.hooks.append(m.register_forward_pre_hook(outlier_hook))

    def is_initialized(self):
        return getattr(self, "initialized", False)

    def get_hvalue(self, weight):
        return weight.data.storage().data_ptr()

    def get_outliers(self, weight):
        if not self.is_initialized():
            print("Outlier tracer is not initialized...")
            return None
        hvalue = self.get_hvalue(weight)
        if hvalue in self.hvalue2outlier_idx:
@@ -80,6 +80,7 @@ class OutlierTracer:
            cls._instance = cls.__new__(cls)
        return cls._instance


def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
    if rdm:
        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
@@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
    m = weight.mean(reduction_dim)
    mm = m.mean()
    mstd = m.std()
    zm = (m - mm) / mstd

    std = weight.std(reduction_dim)
    stdm = std.mean()
    stdstd = std.std()
    zstd = (std - stdm) / stdstd

    if topk is not None:
        val, idx = torch.topk(std.abs(), k=topk, dim=0)
@@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
def execute_and_return(command_string: str) -> Tuple[str, str]:
    def _decode(subprocess_err_out_tuple):
        return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple)

    def execute_and_return_decoded_std_streams(command_string):
        return _decode(
@@ -116,14 +114,13 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
                shlex.split(command_string),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            ).communicate(),
        )

    std_out, std_err = execute_and_return_decoded_std_streams(command_string)
    return std_out, std_err


def replace_linear(
    model,
    linear_replacement,
@@ -163,8 +160,9 @@ def replace_linear(
                model._modules[name].bias = old_module.bias

            if post_processing_function is not None:
                func = getattr(module, post_processing_function, None)
                if func is not None:
                    func(module)
    return model
@@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict):
    A torch tensor containing the packed data.
    """
    json_str = json.dumps(source_dict)
    json_bytes = json_str.encode("utf-8")
    tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8)
    return tensor_data
@@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data):
    A Python dictionary containing the unpacked data.
    """
    json_bytes = bytes(tensor_data.cpu().numpy())
    json_str = json_bytes.decode("utf-8")
    unpacked_dict = json.loads(json_str)
    return unpacked_dict
@@ -2,14 +2,14 @@ import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

p1 = p.data.sum().item()

adam = bnb.optim.Adam([p])

out = a * p
loss = out.sum()
loss.backward()
adam.step()
@@ -17,5 +17,5 @@ adam.step()
p2 = p.data.sum().item()
assert p1 != p2

print("SUCCESS!")
print("Installation was successful!")
@@ -2,23 +2,18 @@ import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

MAX_NEW_TOKENS = 128
model_name = "meta-llama/Llama-2-7b-hf"

text = "Hamburg is in which country?\n"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory)

generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
@@ -19,6 +19,7 @@ cuda_versions = {
    "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
}


def install_cuda(version, base_path, download_path):
    formatted_version = f"{version[:-1]}.{version[-1]}"
    folder = f"cuda-{formatted_version}"
@@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path):
        subprocess.run(["rm", "-rf", install_path], check=True)

    url = cuda_versions[version]
    filename = url.split("/")[-1]
    filepath = os.path.join(download_path, filename)

    if not os.path.exists(filepath):
@@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path):
    # Install CUDA
    print(f"Installing CUDA version {version}...")
    install_command = [
        "bash",
        filepath,
        "--no-drm",
        "--no-man-page",
        "--override",
        "--toolkitpath=" + install_path,
        "--toolkit",
        "--silent",
    ]

    print(f"Running command: {' '.join(install_command)}")
@@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path):
    print(f"CUDA version {version} installed at {install_path}")


def main():
    user_base_path = os.path.expanduser("~/cuda")
    system_base_path = "/usr/local/cuda"
@@ -93,5 +100,6 @@ def main():
        print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -8,6 +8,10 @@ src = [
    "tests",
    "benchmarking"
]
target-version = "py38"
line-length = 119

[tool.ruff.lint]
select = [
    "B", # bugbear: security warnings
    "E", # pycodestyle
@@ -17,7 +21,6 @@ select = [
    "UP", # alert you when better syntax is available in your python version
    "RUF", # the ruff developer's own rules
]
ignore = [
    "B007", # Loop control variable not used within the loop body (TODO: enable)
    "B028", # Warning without stacklevel (TODO: enable)
@@ -30,7 +33,7 @@ ignore = [
]
ignore-init-module-imports = true # allow to expose in __init__.py via imports

[tool.ruff.lint.extend-per-file-ignores]
"**/__init__.py" = ["F401"] # allow unused imports in __init__.py
"{benchmarking,tests}/**/*.py" = [
    "B007",
@@ -42,7 +45,7 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports
    "UP030",
]

[tool.ruff.lint.isort]
combine-as-imports = true
detect-same-package = true
force-sort-within-sections = true
@@ -15,6 +15,7 @@
Script to close stale issue. Taken in part from the AllenNLP repository.
https://github.com/allenai/allennlp.
"""

from datetime import datetime as dt, timezone

import os
@@ -50,7 +51,7 @@ def main():
            issue.create_comment(
                "This issue has been automatically marked as stale because it has not had "
                "recent activity. If you think this still needs to be addressed "
                "please comment on this thread.\n\n",
            )
...@@ -20,7 +20,11 @@ TRANSPOSE_VALS = [(False, True), (False, False)] ...@@ -20,7 +20,11 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"]) @pytest.mark.parametrize(
"funcs",
[(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
...@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
dim3 = dim3 - (dim3 % 16) dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16) dim4 = dim4 - (dim4 % 16)
for i in range(25): for i in range(25):
# normal multiply # normal multiply
if funcs[0] in [torch.mm, torch.matmul]: if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn( target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]: if not transpose[0] and not transpose[1]:
...@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A.grad = None A.grad = None
B.grad = None B.grad = None
loss_torch = torch.nn.functional.mse_loss( loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
out_torch, target
).mean()
loss_torch.backward() loss_torch.backward()
gradA2 = A.grad gradA2 = A.grad
gradB2 = B.grad gradB2 = B.grad
...@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_close( torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1 assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02 assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_close( torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
gradB1, gradB2, atol=0.18, rtol=0.3
)
# batched matrix multiply # batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]: if funcs[0] in [torch.bmm, torch.matmul]:
...@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
n = out_bnb.numel() n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01 assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_close( torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
out_bnb, out_torch, atol=0.027, rtol=0.2
)
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
...@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A.grad = None A.grad = None
B.grad = None B.grad = None
loss_torch = torch.nn.functional.mse_loss( loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
out_torch, target
).mean()
loss_torch.backward() loss_torch.backward()
gradA2 = A.grad gradA2 = A.grad
gradB2 = B.grad gradB2 = B.grad
...@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_close( torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
...@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A.grad = None A.grad = None
B.grad = None B.grad = None
loss_torch = torch.nn.functional.mse_loss( loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
out_torch, target
).mean()
loss_torch.backward() loss_torch.backward()
gradA2 = A.grad gradA2 = A.grad
gradB2 = B.grad gradB2 = B.grad
...@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_close( torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
...@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool ...@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"]) @pytest.mark.parametrize(
"funcs",
[(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)],
ids=["func=matmul", "func=switchback_bnb"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt( def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
decomp,
has_fp16_weights,
has_bias
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
...@@ -245,18 +222,13 @@ def test_matmullt( ...@@ -245,18 +222,13 @@ def test_matmullt(
req_grad[2] = False req_grad[2] = False
for i in range(3): for i in range(3):
# normal multiply # normal multiply
if funcs[0] in [torch.mm, torch.matmul]: if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn( A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
)
if decomp == 6.0: if decomp == 6.0:
with torch.no_grad(): with torch.no_grad():
A[:, outlier_dim] = 6.0 A[:, outlier_dim] = 6.0
B = torch.randn( B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
target = torch.randn( target = torch.randn(
size=(dim2, dim4), size=(dim2, dim4),
device="cuda", device="cuda",
...@@ -266,7 +238,7 @@ def test_matmullt( ...@@ -266,7 +238,7 @@ def test_matmullt(
bias = None bias = None
bias2 = None bias2 = None
if has_bias: if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone() bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
B2 = B.clone() B2 = B.clone()
...@@ -311,9 +283,7 @@ def test_matmullt( ...@@ -311,9 +283,7 @@ def test_matmullt(
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
torch.cuda.synchronize() torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss( loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
out_bnb, target
).mean()
loss_bnb.backward() loss_bnb.backward()
gradA1 = A.grad gradA1 = A.grad
gradB1 = B.grad gradB1 = B.grad
...@@ -323,9 +293,7 @@ def test_matmullt( ...@@ -323,9 +293,7 @@ def test_matmullt(
gradBias1 = bias.grad gradBias1 = bias.grad
bias.grad = None bias.grad = None
loss_torch = torch.nn.functional.mse_loss( loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
out_torch, target
).mean()
loss_torch.backward() loss_torch.backward()
gradA2 = A.grad gradA2 = A.grad
gradB2 = B.grad gradB2 = B.grad
...@@ -336,9 +304,7 @@ def test_matmullt( ...@@ -336,9 +304,7 @@ def test_matmullt(
bias.grad = None bias.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_close( torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
if dim2 > 0: if dim2 > 0:
...@@ -352,9 +318,7 @@ def test_matmullt( ...@@ -352,9 +318,7 @@ def test_matmullt(
assert (idx == 0).sum().item() <= n * 0.1 assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02 assert (idx == 0).sum().item() <= n * 0.02
torch.testing.assert_close( torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
gradB1, gradB2, atol=0.18, rtol=0.3
)
if req_grad[2]: if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2) torch.testing.assert_close(gradBias1, gradBias2)
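The gradient assertions in this test tolerate a bounded fraction of outliers instead of requiring every element to match. A minimal sketch of that pattern, assuming only PyTorch; the helper name assert_mostly_close is illustrative and is not a bitsandbytes or test-suite API:

import torch

def assert_mostly_close(a, b, atol, rtol, max_mismatch_frac):
    # Count elements outside the tolerance band and allow a bounded fraction of them,
    # mirroring the (idx == 0).sum().item() <= n * frac checks above.
    close = torch.isclose(a, b, atol=atol, rtol=rtol)
    mismatched = (~close).sum().item()
    assert mismatched <= a.numel() * max_mismatch_frac, (
        f"{mismatched}/{a.numel()} elements differ beyond atol={atol}, rtol={rtol}"
    )

# e.g. allow up to 10% of gradient entries to miss the loose tolerance:
# assert_mostly_close(gradB1, gradB2, atol=0.06, rtol=0.3, max_mismatch_frac=0.10)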
...@@ -370,8 +334,20 @@ def test_matmullt( ...@@ -370,8 +334,20 @@ def test_matmullt(
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): def test_matmul_4bit(
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
has_bias,
compress_statistics,
quant_type,
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False: if has_bias == False:
...@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias = None bias = None
bias2 = None bias2 = None
if has_bias: if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone() bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) B2, quant_state = bnb.functional.quantize_4bit(
B,
compress_statistics=compress_statistics,
quant_type=quant_type,
)
if not transpose[0] and transpose[1]: if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t()) out_torch = funcs[0](A, B.t())
...@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
if n > 0: if n > 0:
assert err < 0.115 assert err < 0.115
#assert err < 0.20 # assert err < 0.20
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
gradBias1 = bias.grad gradBias1 = bias.grad
bias.grad = None bias.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward() loss_torch.backward()
gradA2 = A.grad gradA2 = A.grad
gradB2 = B.grad gradB2 = B.grad
...@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias.grad = None bias.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[2]: if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2) torch.testing.assert_close(gradBias1, gradBias2)
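The 4-bit path tested here quantizes the weight once and then multiplies against the packed weight together with its quant_state. A rough sketch of that flow, assuming a CUDA device and illustrative shapes, using the quantize_4bit and matmul_4bit calls that appear elsewhere in this diff:

import torch
import bitsandbytes as bnb

A = torch.randn(8, 64, device="cuda", dtype=torch.float16)
B = torch.randn(32, 64, device="cuda", dtype=torch.float16)  # weight stored as (out_features, in_features)

# Pack B into 4-bit NF4 codes; quant_state keeps the absmax statistics needed to dequantize.
B_q, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=True, quant_type="nf4")

# Multiply against the packed weight; this mirrors the torch.matmul(A, B.t()) reference used in the test.
out = bnb.matmul_4bit(A, B_q.t(), quant_state=quant_state)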
...@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global']) @pytest.mark.parametrize(
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): "funcs",
[(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
ids=["matmul_fp8_mixed", "matmul_fp8_global"],
)
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad) req_grad = list(req_grad)
...@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
err = torch.abs(out_bnb - out_torch).float().mean().item() err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0: if n > 0:
assert err < 0.115 assert err < 0.115
#assert err < 0.20 # assert err < 0.20
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None A.grad = None
B.grad = None B.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward() loss_torch.backward()
gradA2 = A.grad gradA2 = A.grad
gradB2 = B.grad gradB2 = B.grad
...@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
...@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
assert (idx == 0).sum().item() <= n * 0.1 assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02 assert (idx == 0).sum().item() <= n * 0.02
grad_err = (gradB1-gradB2).abs().mean() grad_err = (gradB1 - gradB2).abs().mean()
assert grad_err.item() < 0.003 assert grad_err.item() < 0.003
torch.testing.assert_close( torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
gradB1, gradB2, atol=0.18, rtol=0.3
)
...@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): ...@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert ( assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
get_cuda_bnb_library_path(cuda111_noblas_spec).stem
== "libbitsandbytes_cuda111_nocublaslt"
)
...@@ -19,9 +19,7 @@ from tests.helpers import ( ...@@ -19,9 +19,7 @@ from tests.helpers import (
id_formatter, id_formatter,
) )
torch.set_printoptions( torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20 k = 20
...@@ -98,9 +96,7 @@ def teardown(): ...@@ -98,9 +96,7 @@ def teardown():
pass pass
@pytest.mark.parametrize( @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
"dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype): def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda") A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype) A = A.to(dtype)
...@@ -136,7 +132,6 @@ def test_quantile_quantization(): ...@@ -136,7 +132,6 @@ def test_quantile_quantization():
assert diff < 0.001 assert diff < 0.001
def test_dynamic_quantization(): def test_dynamic_quantization():
diffs = [] diffs = []
reldiffs = [] reldiffs = []
...@@ -149,8 +144,8 @@ def test_dynamic_quantization(): ...@@ -149,8 +144,8 @@ def test_dynamic_quantization():
diffs.append(diff.mean().item()) diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item()) reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135 assert diff.mean().item() < 0.0135
print(sum(diffs)/len(diffs)) print(sum(diffs) / len(diffs))
print(sum(reldiffs)/len(reldiffs)) print(sum(reldiffs) / len(reldiffs))
for i in range(100): for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda") A1 = torch.rand(1024, 1024, device="cuda")
...@@ -161,13 +156,12 @@ def test_dynamic_quantization(): ...@@ -161,13 +156,12 @@ def test_dynamic_quantization():
assert diff < 0.004 assert diff < 0.004
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
#print('') # print('')
diffs = [] diffs = []
reldiffs = [] reldiffs = []
for i in range(100): for i in range(100):
...@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): ...@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff = diff / torch.abs(A1.float() + 1e-8) reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item()) diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item()) reldiffs.append(reldiff.mean().item())
abserr = sum(diffs)/len(diffs) abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs)/len(reldiffs) relerr = sum(reldiffs) / len(reldiffs)
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert abserr < 0.011 assert abserr < 0.011
assert relerr < 0.018 assert relerr < 0.018
assert A2.dtype == dtype assert A2.dtype == dtype
...@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): ...@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff = diff / torch.abs(A1.float() + 1e-8) reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item()) diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item()) reldiffs.append(reldiff.mean().item())
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs) abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs)/len(reldiffs) relerr = sum(reldiffs) / len(reldiffs)
if signed: if signed:
assert abserr < 0.0035 assert abserr < 0.0035
assert relerr < 0.015 assert relerr < 0.015
...@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): ...@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
assert abserr < 0.00175 assert abserr < 0.00175
assert relerr < 0.012 assert relerr < 0.012
assert A2.dtype == dtype assert A2.dtype == dtype
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
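The blockwise test above accumulates absolute and relative round-trip error over many random tensors. A condensed sketch of that measurement, assuming a CUDA device and reusing the quantize_blockwise/dequantize_blockwise calls shown later in this file:

import torch
import bitsandbytes.functional as F

diffs, reldiffs = [], []
for _ in range(10):
    A1 = torch.randn(1024, 1024, device="cuda")
    C, state = F.quantize_blockwise(A1, blocksize=256)    # 8-bit codes plus per-block absmax
    A2 = F.dequantize_blockwise(C, state, blocksize=256)
    diff = torch.abs(A1 - A2)
    diffs.append(diff.mean().item())
    reldiffs.append((diff / torch.abs(A1 + 1e-8)).mean().item())

print(sum(diffs) / len(diffs), sum(reldiffs) / len(reldiffs))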
@pytest.mark.parametrize( @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
"gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype): def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda")
...@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype): ...@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
for i in range(k): for i in range(k):
step += 1 step += 1
g = torch.randn(n, n, dtype=gtype, device="cuda") g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping( gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
g, gnorm_vec2, step, percentile=percentile
)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float()) gnorm2 = torch.norm(g.float())
...@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): ...@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim2 = dim2 - (dim2 % 32) dim2 = dim2 - (dim2 % 32)
errors = [] errors = []
relerrors = [] relerrors = []
#print("") # print("")
for i in range(5): for i in range(5):
if batched: if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
...@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): ...@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1) maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0) maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_close( torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
)
if batched: if batched:
out2 = torch.bmm(A, B) out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float()) C = torch.bmm(Ac.float(), Bc.float())
...@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): ...@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
relerr = err / torch.abs(out2) relerr = err / torch.abs(out2)
errors.append(err.mean().item()) errors.append(err.mean().item())
relerrors.append(relerr.mean().item()) relerrors.append(relerr.mean().item())
#print(mean(errors)) # print(mean(errors))
#print(mean(relerrors)) # print(mean(relerrors))
def test_stable_embedding(): def test_stable_embedding():
...@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
batch_dim = batch_dim - (batch_dim % 16) batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16) seq_dim = seq_dim - (seq_dim % 16)
for i in range(k): for i in range(k):
shapeA = ( shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
(batch_dim, hidden_dim) shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
if not transpose[0]
else (hidden_dim, batch_dim)
)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]: if not transpose[0] and not transpose[1]:
...@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k): for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim) shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = ( shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]: if not transpose[0] and not transpose[1]:
...@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): ...@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
hidden_dim = hidden_dim - (hidden_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2) batch_dim = batch_dim - (batch_dim % 2)
for i in range(25): for i in range(25):
A = torch.randint( A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
).to(torch.int8)
B = torch.randint(
-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
).to(torch.int8)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
iout = torch.empty( iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
)
out = F.igemm(A, B, out=iout) out = F.igemm(A, B, out=iout)
torch.testing.assert_close(out.float(), out2) torch.testing.assert_close(out.float(), out2)
...@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): ...@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = [] errs2 = []
relerrs2 = [] relerrs2 = []
for i in range(k): for i in range(k):
A = torch.normal( A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
)
if transpose: if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else: else:
...@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): ...@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B) out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]: elif transpose[0] and transpose[1]:
out2 = torch.bmm( out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
)
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_close(out.float(), out2.float()) torch.testing.assert_close(out.float(), out2.float())
...@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3): ...@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
qA, SA = F.vectorwise_quant(A, dim=0) qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA) A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel() n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
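test_vector_quant above checks a per-vector absmax int8 round trip. A brief sketch of that round trip under the same vectorwise_quant/vectorwise_dequant API used in the test, with an illustrative shape:

import torch
import bitsandbytes.functional as F

A = torch.randn(256, 256, device="cuda").half()
qA, SA = F.vectorwise_quant(A, dim=0)      # int8 codes plus per-vector absmax statistics
A_restored = F.vectorwise_dequant(qA, SA)
print(torch.abs(A.float() - A_restored.float()).mean().item())  # small but nonzero reconstruction error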
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
...@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2: if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3: elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
dtype
)
out, S = F.nvidia_transform(A, to_order=orderOut) out, S = F.nvidia_transform(A, to_order=orderOut)
...@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2: if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3: elif dims == 3:
n = ( n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
A.shape[0]
* A.shape[1]
* (A.shape[2] + (32 - (A.shape[2] % 32)))
)
assert out.numel() == n assert out.numel() == n
elif orderOut == "col_turing": elif orderOut == "col_turing":
# 32 col 8 row tiles # 32 col 8 row tiles
n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
A.shape[1] + (32 - (A.shape[1] % 32))
)
assert out.numel() == n assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]): for row in range(A.shape[0]):
...@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j = col j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0) coltile = (col // 32) + (1 if col % 32 != 0 else 0)
rowtile = ( rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
(row // 8) + (1 if row % 8 != 0 else 0)
) * total_coltile
offset = 32 * 8 * (rowtile + coltile) offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32 col2 = col % 32
row2 = (row % 8) * 32 row2 = (row % 8) * 32
...@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32": if orderOut == "col32":
out2, S = F.nvidia_transform( out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
out, from_order=orderOut, to_order="row", state=S
)
torch.testing.assert_close(A, out2) torch.testing.assert_close(A, out2)
...@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
torch.int8
)
elif dims == 3: elif dims == 3:
A = torch.randint( A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
-128, 127, size=(dim1, dim2, dim3), device="cuda" B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
).to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.t().float()) C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32") A2, SA = F.transform(A, "col32")
...@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): ...@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.testing.assert_close(C1, C3.float()) torch.testing.assert_close(C1, C3.float())
# transpose # transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
torch.int8
)
C1 = torch.matmul(A.float(), B.float()) C1 = torch.matmul(A.float(), B.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True) B2t, SBt = F.transform(B, "col_turing", transpose=True)
...@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): ...@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if dims == 2: if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3: elif dims == 3:
A = torch.normal( A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
0, 0.5, size=(dim1, dim2, dim3), device="cuda"
).half()
B = torch.randn((dim4, dim3), device="cuda").half() B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t()) C1 = torch.matmul(A, B.t())
...@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): ...@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# C3, S = F.transform(C2, 'row', state=SC) # C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_close(C1, C3.float()) # torch.testing.assert_close(C1, C3.float())
@pytest.mark.parametrize( @pytest.mark.parametrize(
("batch", "seq", "model", "hidden"), ("batch", "seq", "model", "hidden"),
[ [
...@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden): ...@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(k): for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1 out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2 # out2 = torch.matmul(out1, w2.t())# fc2
...@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden): ...@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
inner = torch.randint(1, 128, size=(1,)).item() inner = torch.randint(1, 128, size=(1,)).item()
bias = None bias = None
if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
formatB = F.get_special_format_str() formatB = F.get_special_format_str()
for i in range(1): for i in range(1):
A = torch.randn(dim1, inner, device="cuda") A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half()) C1 = torch.matmul(A.half(), B.t().half())
if has_bias: C1 += bias if has_bias:
C1 += bias
A1, maxA = F.vectorwise_quant(A, dim=1) A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1)
...@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): ...@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
C3, S = F.nvidia_transform(C2, "row", state=SC) C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
if has_bias: C4 += bias if has_bias:
C4 += bias
# TODO: is something wrong here? If so, the problem goes deeper # TODO: is something wrong here? If so, the problem goes deeper
# n = C1.numel() # n = C1.numel()
...@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims): ...@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
else: else:
assert False assert False
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
A, threshold=threshold
)
A_blocked = einops.rearrange( A_blocked = einops.rearrange(
torch.abs(A), torch.abs(A),
...@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims): ...@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
torch.testing.assert_close(row_stats1_trunc, row_stats2) torch.testing.assert_close(row_stats1_trunc, row_stats2)
torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
A, threshold=0.0
)
torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_close(row_stats1, row_stats2) torch.testing.assert_close(row_stats1, row_stats2)
...@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2): ...@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel() n = CAt.numel()
num_not_close_rows = ( num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
(torch.isclose(CA, out_row1, atol=1) == 0).sum().item() num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
)
num_not_close_cols = (
(torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
)
# allow for 1:500 error due to rounding differences # allow for 1:500 error due to rounding differences
min_error = 1 / 500 min_error = 1 / 500
if num_not_close_cols > (min_error * n): if num_not_close_cols > (min_error * n):
print( print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
)
assert False assert False
if num_not_close_rows > (min_error * n): if num_not_close_rows > (min_error * n):
print( print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
)
assert False assert False
torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Srow.flatten().float(), statsA)
...@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2): ...@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
("dim1", "dim4", "inner"), ("dim1", "dim4", "inner"),
( (
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner) for (dim1, dim4, inner) in zip(
in zip(
get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4),
get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4),
get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4),
) )
) ),
) )
def test_integrated_igemmlt(dim1, dim4, inner): def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k): for i in range(k):
...@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner): ...@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
("dim1", "dim4", "inner"), ("dim1", "dim4", "inner"),
( (
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner) for (dim1, dim4, inner) in zip(
in zip(
get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6),
) )
) ),
) )
@pytest.mark.skip("Row scale has some bugs for ampere") @pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner): def test_igemmlt_row_scale(dim1, dim4, inner):
...@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): ...@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
c = 10.0 * inner * scale c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt( outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
C3, S = F.nvidia_transform(outC32, "row", state=SC) C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max() maxval = torch.abs(C3).max()
if maxval == 127: if maxval == 127:
...@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner): ...@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(k): for i in range(k):
outC32, SC = F.igemmlt( outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
torch.cuda.synchronize() torch.cuda.synchronize()
print("row-wise", time.time() - t0) print("row-wise", time.time() - t0)
...@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner): ...@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
dtype
)
elif dims == 3: elif dims == 3:
A = torch.randint( A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
10, 99, size=(dim1, dim2, dim3), device="cuda"
).to(dtype)
A.view(-1)[-1] = -1 A.view(-1)[-1] = -1
if transpose: if transpose:
...@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2): ...@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
idx = torch.abs(A) >= threshold idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
A, threshold=threshold
)
if coo_tensor is not None: if coo_tensor is not None:
A1 = A * idx A1 = A * idx
A2 = torch.zeros_like(A) A2 = torch.zeros_like(A)
A2[ A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values
torch.testing.assert_close(A1, A2) torch.testing.assert_close(A1, A2)
A1 = A * (idx == 0) A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close( torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
)
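The COO tests above rely on splitting the input into a dense int8 part and a sparse outlier part. A compact sketch of that decomposition, assuming a CUDA half tensor, using the double_quant call and coo_tensor fields exercised by the test:

import torch
import bitsandbytes.functional as F

threshold = 3.0
A = torch.randn(128, 128, device="cuda").half()

# Row-wise and column-wise int8 quantization; entries with |x| >= threshold are routed to coo_tensor.
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

# Dequantizing the dense part recovers the sub-threshold values of A.
A_dense = (CA.float() * statsA.unsqueeze(1) / 127).half()

# The outliers live in a COO structure (row indices, column indices, values).
if coo_tensor is not None:
    A_outliers = torch.zeros_like(A)
    A_outliers[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values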
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
...@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): ...@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz = (idx == 1).sum().item() nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor( cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx A2 = A * idx
if transposed_B: if transposed_B:
...@@ -1303,9 +1219,7 @@ def test_spmm_bench(): ...@@ -1303,9 +1219,7 @@ def test_spmm_bench():
print(nnz / idx.numel()) print(nnz / idx.numel())
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor( cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
for i in range(10): for i in range(10):
out2 = F.spmm_coo(cooA, B) out2 = F.spmm_coo(cooA, B)
...@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2): ...@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2):
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
A, threshold=threshold
)
C32A, SA = F.transform(CA, "col32") C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
...@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): ...@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
nnz = (idx == 1).sum().item() nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor( cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx A2 = A * idx
out1 = torch.matmul(A2.half(), B.half()) out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device) out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
...@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): ...@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
std = out1.std() std = out1.std()
out1 /= std out1 /= std
out2 /= std out2 /= std
assert_all_approx_close( assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,)) idx_col = torch.randint(0, A2.shape[-1], size=(15,))
...@@ -1443,9 +1351,7 @@ def test_coo2csr(): ...@@ -1443,9 +1351,7 @@ def test_coo2csr():
nnz = (idx == 1).sum().item() nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor( cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx A2 = A * idx
csrA = F.coo2csr(cooA) csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1] counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
...@@ -1463,9 +1369,7 @@ def test_coo2csc(): ...@@ -1463,9 +1369,7 @@ def test_coo2csc():
nnz = (idx == 1).sum().item() nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor( cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx A2 = A * idx
cscA = F.coo2csc(cooA) cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1] counts = cscA.colptr[1:] - cscA.colptr[:-1]
...@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ...@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item() nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx) rows, cols = torch.where(idx)
values = A[idx] values = A[idx]
cooA = F.COOSparseTensor( cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
A2 = A * idx A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half()) out1 = torch.matmul(A2, B.half())
...@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ...@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("batch", "seq", "model", "hidden"), ("batch", "seq", "model", "hidden"),
[pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")], [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")],
) )
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden): def test_bench_matmul(batch, seq, model, hidden):
...@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden): ...@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda() outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0 A[:, :, outliers] = 8.0
linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
#linearMixedBit.eval() # linearMixedBit.eval()
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
...@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden): ...@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(iters): for i in range(iters):
torch.matmul(A, B.t()) torch.matmul(A, B.t())
torch.cuda.synchronize() torch.cuda.synchronize()
print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) print(
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(iters): for i in range(iters):
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize() torch.cuda.synchronize()
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(iters): for i in range(iters):
bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
torch.cuda.synchronize() torch.cuda.synchronize()
print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #torch.cuda.synchronize() # torch.cuda.synchronize()
    #t0 = time.time() # t0 = time.time()
    #for i in range(iters): # for i in range(iters):
# bnb.matmul(A, B) # bnb.matmul(A, B)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# bnb.matmul(A, B, threshold=6.0) # bnb.matmul(A, B, threshold=6.0)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
#C32A, SA = F.transform(CA, "col32") # C32A, SA = F.transform(CA, "col32")
#CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
#CxB, SB = F.transform(CB, to_order=formatB) # CxB, SB = F.transform(CB, to_order=formatB)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1) # BA, statsB = F.vectorwise_quant(B, dim=1)
#CxB, SB = F.nvidia_transform(CB, to_order=formatB) # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous() # A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1) # CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32") # C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) # F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
#CxB, SB = F.nvidia_transform(CB, to_order=formatB) # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous() # A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32") # C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127)) # out = Cout * statsB * statsA * (1.0 / (127 * 127))
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit(A) # linear8bit(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# linear8bit(A) # linear8bit(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linearMixedBit(A) # linearMixedBit(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# linearMixedBit(A) # linearMixedBit(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train(A) # linear8bit_train(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# linear8bit_train(A) # linear8bit_train(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train_thresh(A) # linear8bit_train_thresh(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# linear8bit_train(A) # linear8bit_train(A)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
def test_zeropoint(): def test_zeropoint():
def quant_zp(x): def quant_zp(x):
...@@ -1778,8 +1682,8 @@ def test_zeropoint(): ...@@ -1778,8 +1682,8 @@ def test_zeropoint():
C2 -= A.sum(1).view(-1, 1) * zp C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A) ca, cqa, cza = quant_zp(A)
#print(ca.min(), ca.max()) # print(ca.min(), ca.max())
#print((ca - cza).min(), (ca - cza).max()) # print((ca - cza).min(), (ca - cza).max())
zp = 1 zp = 1
scale = 2.0 scale = 2.0
...@@ -1808,14 +1712,14 @@ def test_zeropoint(): ...@@ -1808,14 +1712,14 @@ def test_zeropoint():
C7 -= zpa * zpb * A.shape[1] C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb C7 /= qa * qb
#print("") # print("")
# print(C0.flatten()[:10]) # print(C0.flatten()[:10])
#print(C1.flatten()[:10]) # print(C1.flatten()[:10])
#print(C2.flatten()[:10]) # print(C2.flatten()[:10])
#print(C3.flatten()[:10]) # print(C3.flatten()[:10])
#print(C5.flatten()[:10]) # print(C5.flatten()[:10])
#print(C6.flatten()[:10]) # print(C6.flatten()[:10])
#print(C7.flatten()[:10]) # print(C7.flatten()[:10])
err1 = torch.abs(C1 - C2).mean().item() err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item() err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item() err3 = torch.abs(C1 - C4).mean().item()
...@@ -1852,16 +1756,15 @@ def test_extract_outliers(): ...@@ -1852,16 +1756,15 @@ def test_extract_outliers():
torch.testing.assert_close(outliers1, outliers2) torch.testing.assert_close(outliers1, outliers2)
def test_blockwise_cpu_large(): def test_blockwise_cpu_large():
diffs = [] diffs = []
reldiffs = [] reldiffs = []
batch = 128 batch = 128
seq = 128 seq = 128
for hidden in [128]:#, 14336]: for hidden in [128]: # , 14336]:
for blocksize in [4096, 16384]: for blocksize in [4096, 16384]:
for i in range(2): for i in range(2):
A1 = torch.randn(batch, seq, hidden, device='cpu') A1 = torch.randn(batch, seq, hidden, device="cpu")
t0 = time.time() t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize) C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
...@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large(): ...@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large():
# print(sum(reldiffs)/len(reldiffs)) # print(sum(reldiffs)/len(reldiffs))
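For readers unfamiliar with what `quantize_blockwise`/`dequantize_blockwise` are being checked against here, this is a rough pure-PyTorch sketch of the idea (each block stores its own absmax scale); it is an illustration under those assumptions, not the bitsandbytes kernel:

```python
import torch

def quantize_blockwise_ref(x: torch.Tensor, blocksize: int = 4096):
    flat = x.flatten()
    pad = (-flat.numel()) % blocksize
    flat = torch.nn.functional.pad(flat, (0, pad))      # pad to a whole number of blocks
    blocks = flat.view(-1, blocksize)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    q = torch.round(blocks / absmax * 127).to(torch.int8)
    return q, absmax

def dequantize_blockwise_ref(q: torch.Tensor, absmax: torch.Tensor) -> torch.Tensor:
    return q.float() / 127 * absmax
```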
def test_fp8_quant(): def test_fp8_quant():
for e_bits in range(1, 7): for e_bits in range(1, 7):
p_bits = 7-e_bits p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda() code = F.create_fp8_map(True, e_bits, p_bits).cuda()
abserr = [] abserr = []
...@@ -1888,12 +1790,12 @@ def test_fp8_quant(): ...@@ -1888,12 +1790,12 @@ def test_fp8_quant():
C, SC = F.quantize_blockwise(A1, code=code) C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC) A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2) diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8) reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item()) abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item()) relerr.append(reldiff.mean().item())
#assert diff < 0.0075 # assert diff < 0.0075
#print(sum(abserr)/len(abserr)) # print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr)) # print(sum(relerr)/len(relerr))
abserr = [] abserr = []
relerr = [] relerr = []
...@@ -1902,12 +1804,12 @@ def test_fp8_quant(): ...@@ -1902,12 +1804,12 @@ def test_fp8_quant():
C, SC = F.quantize_blockwise(A1, code=code) C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC) A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2) diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8) reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item()) abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item()) relerr.append(reldiff.mean().item())
#assert diff < 0.0075 # assert diff < 0.0075
#print(sum(abserr)/len(abserr)) # print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr)) # print(sum(relerr)/len(relerr))
abserr = [] abserr = []
relerr = [] relerr = []
...@@ -1916,50 +1818,48 @@ def test_fp8_quant(): ...@@ -1916,50 +1818,48 @@ def test_fp8_quant():
C, SC = F.quantize_blockwise(A1) C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC) A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2) diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8) reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item()) abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item()) relerr.append(reldiff.mean().item())
#assert diff < 0.0075 # assert diff < 0.0075
#print(3, sum(abserr)/len(abserr)) # print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr)) # print(3, sum(relerr)/len(relerr))
def test_few_bit_quant(): def test_few_bit_quant():
# print('')
#print('')
for bits in range(2, 9): for bits in range(2, 9):
#print('='*30, bits, '='*30) # print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']: for method in ["linear", "fp8", "dynamic", "quantile"]:
abserrs = [] abserrs = []
relerrs = [] relerrs = []
code = None code = None
if method == 'linear': if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda() code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8': elif method == "fp8":
ebits = math.ceil(bits/2) ebits = math.ceil(bits / 2)
pbits = bits-ebits-1 pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda() code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == 'dynamic': elif method == "dynamic":
code = F.create_dynamic_map(True, bits-0, bits).cuda() code = F.create_dynamic_map(True, bits - 0, bits).cuda()
elif method == 'quantile': elif method == "quantile":
values = torch.randn(2048, 2048, device='cuda') values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda() code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero # for some data types we have no zero
# for some data types we have one zero # for some data types we have one zero
# for some data types we have two zeros # for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}' assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
#print(method, (code==0).sum()) # print(method, (code==0).sum())
assert code.numel() == 256 assert code.numel() == 256
for i in range(10): for i in range(10):
values = torch.randn(1, 32, device="cuda")
values = torch.randn(1, 32, device='cuda')
values /= values.abs().max() values /= values.abs().max()
#values[values.abs() < 1e-6] += 1e-5 # values[values.abs() < 1e-6] += 1e-5
q1 = [] q1 = []
v1 = [] v1 = []
for v in values[0]: for v in values[0]:
idx = torch.abs(v-code).argmin() idx = torch.abs(v - code).argmin()
q1.append(idx.item()) q1.append(idx.item())
v1.append(code[idx].item()) v1.append(code[idx].item())
...@@ -1970,62 +1870,61 @@ def test_few_bit_quant(): ...@@ -1970,62 +1870,61 @@ def test_few_bit_quant():
v2 = F.dequantize_blockwise(q2, S2) v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int()) idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2-values) err2 = torch.abs(v2 - values)
abserrs.append(err2.mean().item()) abserrs.append(err2.mean().item())
relerrs.append((err2/(1e-10+values).abs()).mean().item()) relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
if idx.sum(): if idx.sum():
# some weird cases # some weird cases
err1 = torch.abs(v1-values).mean() err1 = torch.abs(v1 - values).mean()
#assert err2.mean() <= err1 # assert err2.mean() <= err1
else: else:
torch.testing.assert_close(q1, q2) torch.testing.assert_close(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False # assert False
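`test_few_bit_quant` compares the CUDA blockwise path against a brute-force nearest-code lookup. As a compact, CPU-only restatement of that reference path (independent of the code tables created above):

```python
import torch

def linear_code(bits: int) -> torch.Tensor:
    # 2**bits evenly spaced code values in [-1, 1]
    return torch.linspace(-1.0, 1.0, 2**bits)

def quantize_nearest(x: torch.Tensor, code: torch.Tensor):
    # index of the closest code entry for every element of x
    idx = (x.unsqueeze(-1) - code).abs().argmin(dim=-1)
    return idx, code[idx]

x = torch.randn(1, 32)
x = x / x.abs().max()                       # normalize like the test does
idx, xq = quantize_nearest(x, linear_code(4))
print((x - xq).abs().mean().item())
```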
def test_kbit_quantile_estimation(): def test_kbit_quantile_estimation():
for i in range(100): for i in range(100):
data = torch.randn(1024, 1024, device='cuda') data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 9): for bits in range(2, 9):
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda() val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean() err = torch.abs(val1 - val2).mean()
assert err < 0.038 assert err < 0.038
for i in range(100): for i in range(100):
data = torch.randn(1024, 1024, device='cuda') data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 4): for bits in range(2, 4):
total_values = 2**bits-1 total_values = 2**bits - 1
p = np.linspace(0, 1, 2*total_values+1) p = np.linspace(0, 1, 2 * total_values + 1)
idx = np.arange(1, 2*total_values+1, 2) idx = np.arange(1, 2 * total_values + 1, 2)
p = p[idx] p = p[idx]
offset = 1/(2*total_values) offset = 1 / (2 * total_values)
p = np.linspace(offset, 1-offset, total_values) p = np.linspace(offset, 1 - offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda() val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1) val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
err = torch.abs(val1-val2).mean() err = torch.abs(val1 - val2).mean()
assert err < 0.035 assert err < 0.035
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_dequantization(): def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half() a = torch.rand(1024, 1024, device="cuda").half()
code =F.create_fp8_map(True, 3, 0, 4).cuda() code = F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code) qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max()) print(qa.max())
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
#print(max_theoretical_mu) # print(max_theoretical_mu)
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(100): for i in range(100):
qa, SA = F.quantize_blockwise(a) qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize() torch.cuda.synchronize()
#print((time.time()-t0)/1e6) # print((time.time()-t0)/1e6)
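The `max_theoretical_mu` value above is a bandwidth-bound lower limit: bytes that must cross device memory divided by an assumed memory bandwidth (the constant 672 looks like a GB/s figure for the benchmarking GPU; treat that as an assumption and substitute your own device's number). Written out:

```python
n_elements = 1024 * 1024
bytes_moved = n_elements * 2             # fp16 elements read, a rough lower bound
seconds = bytes_moved / 1024**3 / 672    # GiB moved divided by an assumed GB/s bandwidth
print(f"~{seconds * 1e6:.2f} microseconds per call at best")
```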
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
...@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype): ...@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype):
result = 0 result = 0
bias = 3 bias = 3
sign, e1, e2, p1 = bits sign, e1, e2, p1 = bits
idx = sign*8 + e1*4 + e2*2 + p1*1 idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
sign = -1.0 if sign else 1.0 sign = -1.0 if sign else 1.0
exp = e1*2 + e2*1 exp = e1 * 2 + e2 * 1
if exp == 0: if exp == 0:
# sub-normal # sub-normal
if p1 == 0: result = 0 if p1 == 0:
else: result = sign*0.0625 result = 0
else:
result = sign * 0.0625
else: else:
# normal # normal
exp = 2**(-exp + bias + 1) exp = 2 ** (-exp + bias + 1)
frac = 1.5 if p1 else 1.0 frac = 1.5 if p1 else 1.0
result = sign*exp*frac result = sign * exp * frac
code[idx] = result code[idx] = result
A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
qa, SA = F.quantize_fp4(A1, blocksize=64) qa, SA = F.quantize_fp4(A1, blocksize=64)
A2 = F.dequantize_fp4(qa, SA) A2 = F.dequantize_fp4(qa, SA)
err = (A1 - A2).abs().float() err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-8)).mean() relerr = (err / (A1.abs().float() + 1e-8)).mean()
idx = err > 1.0 idx = err > 1.0
err = err.mean() err = err.mean()
...@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype): ...@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype):
assert relerr.item() < 0.28 assert relerr.item() < 0.28
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type): def test_4bit_compressed_stats(quant_type):
for blocksize in [128, 64]: for blocksize in [128, 64]:
errs1 = [] errs1 = []
errs2 = [] errs2 = []
for i in range(10): for i in range(10):
A1 = torch.randn(1024, 1024, device='cuda').half() A1 = torch.randn(1024, 1024, device="cuda").half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float() err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean() relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean() err = err.mean()
errs1.append(err.item()) errs1.append(err.item())
assert err.item() < 0.11 assert err.item() < 0.11
assert relerr.item() < 0.28 assert relerr.item() < 0.28
err = (A1 - A3).abs().float() err = (A1 - A3).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean() relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean() err = err.mean()
errs2.append(err.item()) errs2.append(err.item())
...@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type): ...@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type):
assert err.item() < 0.11 assert err.item() < 0.11
assert relerr.item() < 0.28 assert relerr.item() < 0.28
#print(sum(errs1)/len(errs1), blocksize, quant_type) # print(sum(errs1)/len(errs1), blocksize, quant_type)
#print(sum(errs2)/len(errs2), blocksize, quant_type) # print(sum(errs2)/len(errs2), blocksize, quant_type)
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) @pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.parametrize("quant_type", ['nf4'])
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type): def test_bench_4bit_dequant(quant_type):
blocksize = 256 blocksize = 256
a = torch.rand(1024*12*4, 1024*12, device='cuda').half() a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel()/2 input_size = a.numel() / 2
output_size = a.numel()*2 output_size = a.numel() * 2
num_bytes = input_size+output_size num_bytes = input_size + output_size
GB = num_bytes/1e9 GB = num_bytes / 1e9
max_theoretical_s = GB/768 max_theoretical_s = GB / 768
#print(max_theoretical_s*1e6) # print(max_theoretical_s*1e6)
b = torch.randn(128, 1024*12, device='cuda').half() b = torch.randn(128, 1024 * 12, device="cuda").half()
iters = 100 iters = 100
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(iters): for i in range(iters):
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
#b.copy_(a) # b.copy_(a)
torch.cuda.synchronize() torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6) # print((time.time()-t0)/iters*1e6)
#torch.cuda.synchronize() # torch.cuda.synchronize()
#t0 = time.time() # t0 = time.time()
#for i in range(iters): # for i in range(iters):
# torch.matmul(b, a.t()) # torch.matmul(b, a.t())
#torch.cuda.synchronize() # torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6) # print((time.time()-t0)/iters*1e6)
def test_normal_map_tree(): def test_normal_map_tree():
code = F.create_normal_map() code = F.create_normal_map()
values =code[:8].tolist() + code[-8:].tolist() values = code[:8].tolist() + code[-8:].tolist()
num_pivots = 1 num_pivots = 1
#print(values) # print(values)
while num_pivots <16: while num_pivots < 16:
idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
#print(idx) # print(idx)
num_pivots *= 2 num_pivots *= 2
pivots = [] pivots = []
for i in idx: for i in idx:
pivots.append((values[i-1]+values[i])/2) pivots.append((values[i - 1] + values[i]) / 2)
#print(pivots) # print(pivots)
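`test_normal_map_tree` builds midpoints between neighbouring code values; those midpoints are exactly the decision boundaries a binary search would use to map an input to its nearest code entry. A small illustration of that idea with a generic sorted code (hypothetical values, not the actual normal map):

```python
import torch

code = torch.linspace(-1.0, 1.0, 16)     # stand-in for a sorted 16-entry code
pivots = (code[:-1] + code[1:]) / 2      # nearest-neighbour decision boundaries
x = torch.tensor([-0.93, 0.02, 0.71])
idx = torch.searchsorted(pivots, x)      # index of the nearest code entry
print(idx, code[idx])
```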
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4']) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed']) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize(
"quant_storage",
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]: for dim in [128, 256, 512, 1024]:
#for dim in [4*1024]: # for dim in [4*1024]:
#for dim in [1*16]: # for dim in [1*16]:
errs1 = [] errs1 = []
errs2 = [] errs2 = []
errs3 = [] errs3 = []
...@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): ...@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
max_errs2 = [] max_errs2 = []
max_errs3 = [] max_errs3 = []
for i in range(100): for i in range(100):
if kind == 'fc1': if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device='cuda') A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim) B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == 'fc2': elif kind == "fc2":
A = torch.randn(1, 4*dim, dtype=dtype, device='cuda') A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim) B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == 'attn': elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device='cuda') A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim) B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == 'attn_packed': elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device='cuda') A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) qB, state = F.quantize_4bit(
B,
quant_type=storage_type,
compress_statistics=double_quant,
quant_storage=quant_storage,
)
C3 = torch.matmul(A, B.t()) C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state) C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state) C1 = bnb.matmul_4bit(A, qB.t(), state)
err1 = (C1-C2).abs().float() err1 = (C1 - C2).abs().float()
err2 = (C3-C2).abs().float() err2 = (C3 - C2).abs().float()
err3 = (C3-C1).abs().float() err3 = (C3 - C1).abs().float()
mag1 = torch.abs(C1).float()+1e-5 mag1 = torch.abs(C1).float() + 1e-5
mag2 = torch.abs(C3).float()+1e-5 mag2 = torch.abs(C3).float() + 1e-5
mag3 = torch.abs(C3).float()+1e-5 mag3 = torch.abs(C3).float() + 1e-5
relerr1 = err1/mag1 relerr1 = err1 / mag1
relerr2 = err2/mag2 relerr2 = err2 / mag2
relerr3 = err3/mag3 relerr3 = err3 / mag3
max_err1 = err1.max() max_err1 = err1.max()
max_err2 = err2.max() max_err2 = err2.max()
...@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): ...@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
max_errs2.append(max_err2.item()) max_errs2.append(max_err2.item())
max_errs3.append(max_err3.item()) max_errs3.append(max_err3.item())
c = int(C1.numel()*0.0014*(dim/256))+1 c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
err1 = sum(errs1)/len(errs1)/math.sqrt(dim) err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
err2 = sum(errs2)/len(errs2)/math.sqrt(dim) err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
err3 = sum(errs3)/len(errs3)/math.sqrt(dim) err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim) relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim) relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim) relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim) maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim) maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim) maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
absratio = err2/err3 absratio = err2 / err3
relratio = relerr2/relerr3 relratio = relerr2 / relerr3
maxratio = relerr2/relerr3 maxratio = relerr2 / relerr3
# for debugging if the test fails # for debugging if the test fails
# #
#print('='*80) # print('='*80)
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#print(C1.flatten()[-20:]) # print(C1.flatten()[-20:])
#print(C2.flatten()[-20:]) # print(C2.flatten()[-20:])
#print(f'inference vs training abs: {err1}') # print(f'inference vs training abs: {err1}')
#print(f'inference vs training rel: {relerr1}') # print(f'inference vs training rel: {relerr1}')
#print(f'inference vs training max: {maxerr1}') # print(f'inference vs training max: {maxerr1}')
#print(f'inference vs training vs torch err ratio abs: {absratio}') # print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}') # print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}') # print(f'inference vs training vs torch err ratio max: {maxratio}')
if dtype == torch.float16: if dtype == torch.float16:
if dim <= 512: if dim <= 512:
assert err1 < 7e-5 assert err1 < 7e-5
...@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): ...@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
assert relratio < 1.04 and relratio > 0.96 assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98 assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.skip("Row scale has some bugs for ampere") @pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed(): def test_managed():
n = 32*10 n = 32 * 10
A = F.get_paged(n, n, dtype=torch.float32) A = F.get_paged(n, n, dtype=torch.float32)
B = F.get_paged(n, n, dtype=torch.uint8) B = F.get_paged(n, n, dtype=torch.uint8)
B2 = F.get_paged(n, n, dtype=torch.float32) B2 = F.get_paged(n, n, dtype=torch.float32)
assert A.is_paged assert A.is_paged
assert B.is_paged assert B.is_paged
assert A.page_deviceid==0 assert A.page_deviceid == 0
assert B.page_deviceid==0 assert B.page_deviceid == 0
F.fill(A, 17.0) F.fill(A, 17.0)
F.fill(B, 17) F.fill(B, 17)
F.fill(B2, 2) F.fill(B2, 2)
assert (A==17).sum().item() == n*n assert (A == 17).sum().item() == n * n
assert (B==17).sum().item() == n*n assert (B == 17).sum().item() == n * n
C = A*B.float() C = A * B.float()
assert (C==289).sum().item() == n*n assert (C == 289).sum().item() == n * n
F._mul(A, B2) F._mul(A, B2)
F._mul(A, B2) F._mul(A, B2)
F._mul(A, B2) F._mul(A, B2)
assert (A==17*(2**3)).sum().item() == n*n assert (A == 17 * (2**3)).sum().item() == n * n
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.fill(B2, 17.0) # F.fill(B2, 17.0)
# F._mul(A, B2) # F._mul(A, B2)
# F.prefetch_tensor(A, to_cpu=True) # F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True) # F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True) # F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# assert (A==17).sum().item() == n*n # assert (A==17).sum().item() == n*n
# torch.testing.assert_close(A, torch.ones(A.shape)*289) # torch.testing.assert_close(A, torch.ones(A.shape)*289)
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True']) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant): def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10 dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242)) torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims) dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64-(dim % 64)) for dim in dims] dims = [dim + (64 - (dim % 64)) for dim in dims]
#for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for dim in dims: for dim in dims:
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda') A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
B = torch.eye(dim, dtype=dtype, device='cuda') B = torch.eye(dim, dtype=dtype, device="cuda")
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t()) C3 = torch.matmul(A, B.t())
...@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): ...@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
torch.testing.assert_close(A, C3) torch.testing.assert_close(A, C3)
torch.testing.assert_close(A, C1) torch.testing.assert_close(A, C1)
torch.testing.assert_close(A, C2) torch.testing.assert_close(A, C2)
#torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
#torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
...@@ -10,56 +10,61 @@ transformers = pytest.importorskip("transformers") ...@@ -10,56 +10,61 @@ transformers = pytest.importorskip("transformers")
def get_4bit_config(): def get_4bit_config():
return transformers.BitsAndBytesConfig( return transformers.BitsAndBytesConfig(
load_in_4bit=True, load_in_4bit=True,
load_in_8bit=False, load_in_8bit=False,
llm_int8_threshold=6.0, llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False, llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16, bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4', bnb_4bit_quant_type="nf4",
) )
def get_model_and_tokenizer(config): def get_model_and_tokenizer(config):
model_name_or_path, quant_type = config model_name_or_path, quant_type = config
bnb_config = get_4bit_config() bnb_config = get_4bit_config()
if quant_type == '16bit': if quant_type == "16bit":
bnb_config.load_in_4bit = False bnb_config.load_in_4bit = False
else: else:
bnb_config.bnb_4bit_quant_type= quant_type bnb_config.bnb_4bit_quant_type = quant_type
model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, model = transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=bnb_config, quantization_config=bnb_config,
max_memory={0:'48GB'}, max_memory={0: "48GB"},
device_map='auto', device_map="auto",
torch_dtype=torch.bfloat16 torch_dtype=torch.bfloat16,
).eval() ).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
return model, tokenizer return model, tokenizer
def get_prompt_for_generation_eval(text, add_roles=True): def get_prompt_for_generation_eval(text, add_roles=True):
description = ( description = (
"A chat between a curious human and an artificial intelligence assistant. " "A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions." "The assistant gives helpful, detailed, and polite answers to the user's questions."
) )
if add_roles: if add_roles:
prompt = f'{description} ### Human: {text} ### Assistant:' prompt = f"{description} ### Human: {text} ### Assistant:"
else: else:
prompt = f'{description} {text}' prompt = f"{description} {text}"
return prompt return prompt
def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval):
text = prompt_func(text) text = prompt_func(text)
inputs = tokenizer(text, return_tensors="pt").to('cuda:0') inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config)
return tokenizer.decode(outputs[0], skip_special_tokens=True) return tokenizer.decode(outputs[0], skip_special_tokens=True)
models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
dtypes = ['nf4', 'fp4']
@pytest.fixture(scope='session', params=product(models, dtypes)) models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"]
dtypes = ["nf4", "fp4"]
@pytest.fixture(scope="session", params=product(models, dtypes))
def model_and_tokenizer(request): def model_and_tokenizer(request):
model, tokenizer = get_model_and_tokenizer(request.param) model, tokenizer = get_model_and_tokenizer(request.param)
yield request.param, model, tokenizer yield request.param, model, tokenizer
...@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): ...@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
) )
generation_config.max_new_tokens = 20 generation_config.max_new_tokens = 20
# text = 'Please write down the first 50 digits of pi.'
#text = 'Please write down the first 50 digits of pi.' # text = get_prompt_for_generation_eval(text)
#text = get_prompt_for_generation_eval(text) # text += ' Sure, here the first 50 digits of pi: 3.14159'
#text += ' Sure, here the first 50 digits of pi: 3.14159'
n_cases = 6 n_cases = 6
text = '3.14159' text = "3.14159"
if hasattr(model.config, 'quantization_config'): if hasattr(model.config, "quantization_config"):
model.config.quantization_config.bnb_4bit_compute_dtype = dtype model.config.quantization_config.bnb_4bit_compute_dtype = dtype
model.config.quantization_config.bnb_4bit_use_double_quant = DQ model.config.quantization_config.bnb_4bit_use_double_quant = DQ
if not inference_kernel: if not inference_kernel:
text = [text]*n_cases text = [text] * n_cases
inputs = tokenizer(text, return_tensors="pt").to('cuda:0') inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
x = inputs['input_ids'] x = inputs["input_ids"]
outputs = [] outputs = []
if inference_kernel: if inference_kernel:
for i in range(n_cases): for i in range(n_cases):
...@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): ...@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
outputs = model.generate(x, generation_config=generation_config) outputs = model.generate(x, generation_config=generation_config)
outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
assert len(outputs) == n_cases assert len(outputs) == n_cases
failure_count = 0 failure_count = 0
for i in range(n_cases): for i in range(n_cases):
if not outputs[i][:len(str(math.pi))] == str(math.pi): if not outputs[i][: len(str(math.pi))] == str(math.pi):
failure_count += 1 failure_count += 1
failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4
if failure_count > failure_max: if failure_count > failure_max:
print(math.pi) print(math.pi)
for out in outputs: for out in outputs:
print(out) print(out)
raise ValueError(f'Failure count: {failure_count}/{n_cases}') raise ValueError(f"Failure count: {failure_count}/{n_cases}")
...@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora ...@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
device = "cuda" device = "cuda"
layer_shape = (300, 400) layer_shape = (300, 400)
linear = torch.nn.Linear( linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
*layer_shape, dtype=original_dtype, device="cpu"
) # original layer
# Quantizing original layer # Quantizing original layer
linear_q = bnb.nn.Linear4bit( linear_q = bnb.nn.Linear4bit(
...@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora ...@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_type=quant_type, quant_type=quant_type,
device="meta", device="meta",
) )
new_weight = bnb.nn.Params4bit( new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
data=linear.weight, quant_type=quant_type, requires_grad=False
)
linear_q.weight = new_weight linear_q.weight = new_weight
if bias: if bias:
linear_q.bias = torch.nn.Parameter(linear.bias) linear_q.bias = torch.nn.Parameter(linear.bias)
...@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora ...@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
target_compression = ( target_compression = (
0.143 if original_dtype == torch.float32 else 0.29 0.143 if original_dtype == torch.float32 else 0.29
) # these numbers get lower as weight shape increases ) # these numbers get lower as weight shape increases
ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" ratio_error_msg = (
f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
)
assert size_ratio < target_compression, ratio_error_msg assert size_ratio < target_compression, ratio_error_msg
......
...@@ -19,6 +19,7 @@ from tests.helpers import ( ...@@ -19,6 +19,7 @@ from tests.helpers import (
# contributed by Alex Borzunov, see: # contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif( @pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
...@@ -50,7 +51,9 @@ def test_linear_no_igemmlt(): ...@@ -50,7 +51,9 @@ def test_linear_no_igemmlt():
linear_custom.state.force_no_igemmlt = True linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params( linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False linear.weight.data.clone(),
requires_grad=False,
has_fp16_weights=False,
).to(linear.weight.dtype) ).to(linear.weight.dtype)
linear_custom.bias = linear.bias linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda() linear_custom = linear_custom.cuda()
...@@ -77,7 +80,14 @@ def test_linear_no_igemmlt(): ...@@ -77,7 +80,14 @@ def test_linear_no_igemmlt():
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda): def test_linear_serialization(
has_fp16_weights,
serialize_before_forward,
deserialize_before_cuda,
force_no_igemmlt,
save_before_forward,
load_before_cuda,
):
linear = torch.nn.Linear(32, 96) linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half) x = torch.randn(3, 32, dtype=torch.half)
...@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri ...@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
linear_custom.state.force_no_igemmlt = True linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params( linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights linear.weight.data.clone(),
requires_grad=has_fp16_weights,
has_fp16_weights=has_fp16_weights,
) )
linear_custom.bias = linear.bias linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda() linear_custom = linear_custom.cuda()
......
...@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module): ...@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super().__init__() super().__init__()
self.fc1 = bnb.nn.Linear8bitLt( self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, dim1,
threshold=threshold dim2,
has_fp16_weights=has_fp16_weights,
memory_efficient_backward=memory_efficient_backward,
threshold=threshold,
) )
self.fc2 = bnb.nn.Linear8bitLt( self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, dim2,
threshold=threshold dim1,
has_fp16_weights=has_fp16_weights,
memory_efficient_backward=memory_efficient_backward,
threshold=threshold,
) )
def forward(self, x): def forward(self, x):
...@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): ...@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function): class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = ( round_func = LinearFunction.round_stoachastic if stochastic else torch.round
LinearFunction.round_stoachastic if stochastic else torch.round
)
norm = math.sqrt(math.pi) / math.sqrt(2.0) norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm # std = torch.abs(x).mean()*norm
std = torch.std(x) std = torch.std(x)
...@@ -122,9 +126,7 @@ class LinearFunction(torch.autograd.Function): ...@@ -122,9 +126,7 @@ class LinearFunction(torch.autograd.Function):
return x.to(dtype) return x.to(dtype)
def get_8bit_linear(x, stochastic=False): def get_8bit_linear(x, stochastic=False):
round_func = ( round_func = LinearFunction.round_stoachastic if stochastic else torch.round
LinearFunction.round_stoachastic if stochastic else torch.round
)
max1 = torch.abs(x).max() max1 = torch.abs(x).max()
x = x / max1 * 127 x = x / max1 * 127
x = round_func(x) / 127 * max1 x = round_func(x) / 127 * max1
...@@ -133,9 +135,7 @@ class LinearFunction(torch.autograd.Function): ...@@ -133,9 +135,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False): def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = ( round_func = LinearFunction.round_stoachastic if stochastic else torch.round
LinearFunction.round_stoachastic if stochastic else torch.round
)
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0 max1[max1 == 0] = 1.0
x = (x * 127) / max1 x = (x * 127) / max1
...@@ -219,9 +219,7 @@ class LinearFunction(torch.autograd.Function): ...@@ -219,9 +219,7 @@ class LinearFunction(torch.autograd.Function):
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t()) outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant( output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
outputq, S1, S2, x.dtype, args.quant_type
)
# if torch.rand(1) < 0.01: # if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t()) # output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float() # err = torch.abs(output-output32).float()
...@@ -250,37 +248,25 @@ class LinearFunction(torch.autograd.Function): ...@@ -250,37 +248,25 @@ class LinearFunction(torch.autograd.Function):
# weight and x are already 8bit # weight and x are already 8bit
# -> transform grad_output to 8-bit # -> transform grad_output to 8-bit
if args.use_8bit_training == "forward+wgrad": if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant( grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8) grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant( grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == "full": elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant( grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8) bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant( grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_output8, S1 = LinearFunction.quant( grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
grad_output, args.quant_type, dim=2
)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8) grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant( grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
grad_input8, S1, S3, grad_output.dtype, args.quant_type
)
else: else:
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
...@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient(): ...@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True) opt1.zero_grad(True)
opt2.step() opt2.step()
opt2.zero_grad(True) opt2.zero_grad(True)
assert_all_approx_close( assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
)
assert_all_approx_close(
l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
)
# we do this copy because otherwise we have small divergences over time that add up # we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data) l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data) l1[1].weight.data.copy_(l2[1].weight.data)
...@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient(): ...@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient():
@pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False]) @pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) l1 = (
bnb.nn.Linear8bitLt(
32,
64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
.cuda()
.half()
)
assert l1.weight.dtype == torch.int8 assert l1.weight.dtype == torch.int8
l1.eval() l1.eval()
...@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): ...@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
if threshold > 0: if threshold > 0:
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
mlp = ( mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
...@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): ...@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
if threshold > 0: if threshold > 0:
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
mlp = ( mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.half()
.cuda()
)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device="cuda").half()
...@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): ...@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) mlp = (
MLP8bit(
32,
64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
.half()
.to("cuda")
)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device="cuda").half()
...@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): ...@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc2.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == "cuda"
mlp = MLP8bit( mlp = MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward 32,
) 64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization mlp = mlp.cuda().half() # and this line triggers quantization
...@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): ...@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
bnb.nn.LinearFP4, bnb.nn.LinearFP4,
], ],
ids=['Int8Lt', 'FP4'], ids=["Int8Lt", "FP4"],
) )
def test_linear_kbit_fp32_bias(module): def test_linear_kbit_fp32_bias(module):
# casts model to fp16 -> int8 automatically # casts model to fp16 -> int8 automatically
...@@ -544,7 +542,7 @@ def test_kbit_backprop(module): ...@@ -544,7 +542,7 @@ def test_kbit_backprop(module):
kbit[1].bias.detach().copy_(ref[1].bias) kbit[1].bias.detach().copy_(ref[1].bias)
ref = ref.half().cuda() ref = ref.half().cuda()
kbit = kbit.half().cuda() kbit = kbit.half().cuda()
kbit = kbit.half().to('cuda') kbit = kbit.half().to("cuda")
errs1 = [] errs1 = []
errs2 = [] errs2 = []
...@@ -562,10 +560,10 @@ def test_kbit_backprop(module): ...@@ -562,10 +560,10 @@ def test_kbit_backprop(module):
bgrad1 = ref[0].bias.grad bgrad1 = ref[0].bias.grad
bgrad2 = kbit[0].bias.grad bgrad2 = kbit[0].bias.grad
err1 = (out1-out2).abs().float() err1 = (out1 - out2).abs().float()
err2 = (grad1-grad2).abs().float() err2 = (grad1 - grad2).abs().float()
relerr1 = (err1/(out1.abs().float()+1e-9)) relerr1 = err1 / (out1.abs().float() + 1e-9)
relerr2 = (err2/(grad1.abs().float()+1e-9)) relerr2 = err2 / (grad1.abs().float() + 1e-9)
errs1.append(err1.mean().item()) errs1.append(err1.mean().item())
errs2.append(err2.mean().item()) errs2.append(err2.mean().item())
relerrs1.append(relerr1.mean().item()) relerrs1.append(relerr1.mean().item())
...@@ -582,20 +580,20 @@ def test_kbit_backprop(module): ...@@ -582,20 +580,20 @@ def test_kbit_backprop(module):
assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
#print('out', sum(errs1)/len(errs1)) # print('out', sum(errs1)/len(errs1))
#print('grad', sum(errs2)/len(errs2)) # print('grad', sum(errs2)/len(errs2))
#print('rel out', sum(relerrs1)/len(relerrs1)) # print('rel out', sum(relerrs1)/len(relerrs1))
#print('rel grad', sum(relerrs2)/len(relerrs2)) # print('rel grad', sum(relerrs2)/len(relerrs2))
def test_fp8linear():
def test_fp8linear():
b = 10 b = 10
h = 1024 h = 1024
inp = torch.randn(b, h).cuda() inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h*2).cuda() fp32 = torch.nn.Linear(h, h * 2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
fp32b = torch.nn.Linear(h*2, h).cuda() fp32b = torch.nn.Linear(h * 2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
fp8.weight.data.copy_(fp32.weight.data) fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data) fp8.bias.data.copy_(fp32.bias.data)
...@@ -605,34 +603,34 @@ def test_fp8linear(): ...@@ -605,34 +603,34 @@ def test_fp8linear():
a = fp32b(torch.nn.functional.gelu(fp32(inp))) a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp))) b = fp8b(torch.nn.functional.gelu(fp8(inp)))
err = (a-b).abs().mean() err = (a - b).abs().mean()
a.mean().backward() a.mean().backward()
b.mean().backward() b.mean().backward()
graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
assert err < 0.05 assert err < 0.05
assert graderr < 0.00002 assert graderr < 0.00002
assert bgraderr < 0.00002 assert bgraderr < 0.00002
def test_4bit_warnings(): def test_4bit_warnings():
dim1 = 64 dim1 = 64
with pytest.warns(UserWarning, match=r'inference or training'): with pytest.warns(UserWarning, match=r"inference or training"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda() net = net.cuda()
inp = torch.rand(10, dim1).cuda().half() inp = torch.rand(10, dim1).cuda().half()
net(inp) net(inp)
with pytest.warns(UserWarning, match=r'inference.'): with pytest.warns(UserWarning, match=r"inference."):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda() net = net.cuda()
inp = torch.rand(1, dim1).cuda().half() inp = torch.rand(1, dim1).cuda().half()
net(inp) net(inp)
with pytest.warns(UserWarning) as record: with pytest.warns(UserWarning) as record:
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda() net = net.cuda()
inp = torch.rand(10, dim1).cuda().half() inp = torch.rand(10, dim1).cuda().half()
......
...@@ -16,6 +16,7 @@ from tests.helpers import describe_dtype, id_formatter ...@@ -16,6 +16,7 @@ from tests.helpers import describe_dtype, id_formatter
k = 20 k = 20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol=rtol, atol=atol) idx = torch.isclose(a, b, rtol=rtol, atol=atol)
error_count = (idx == 0).sum().item() error_count = (idx == 0).sum().item()
...@@ -33,6 +34,7 @@ def get_temp_dir(): ...@@ -33,6 +34,7 @@ def get_temp_dir():
def rm_path(path): def rm_path(path):
shutil.rmtree(path) shutil.rmtree(path)
str2optimizers = {} str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
...@@ -66,8 +68,14 @@ str2optimizers["rmsprop8bit"] = ( ...@@ -66,8 +68,14 @@ str2optimizers["rmsprop8bit"] = (
) )
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) str2optimizers["paged_adamw8bit_blockwise"] = (
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) torch.optim.AdamW,
lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)
str2optimizers["paged_adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = ( str2optimizers["momentum8bit_blockwise"] = (
...@@ -90,9 +98,18 @@ str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] ...@@ -90,9 +98,18 @@ str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] str2statenames["adam8bit_blockwise"] = [
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] ("exp_avg", "state1", "qmap1", "absmax1"),
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
...@@ -101,7 +118,7 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab ...@@ -101,7 +118,7 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"]
@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt")) @pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
...@@ -109,7 +126,7 @@ optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_ad ...@@ -109,7 +126,7 @@ optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_ad
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip() pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
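For readers skimming the diff: the str2optimizers registry above maps each parametrized name to a (reference, bitsandbytes) pair, and tests such as test_optimizer32bit build both optimizers over identical parameters, apply the same gradient, and compare parameters and state with assert_most_approx_close. A minimal sketch of that pattern (tensor shapes and hyperparameters are illustrative, not taken from the tests):

import torch
import bitsandbytes as bnb

p_ref = torch.randn(1024, 32, device="cuda", requires_grad=True)
p_bnb = p_ref.clone().detach().requires_grad_(True)

torch_optim = torch.optim.Adam([p_ref], lr=1e-3)   # reference implementation
bnb_optim = bnb.optim.Adam([p_bnb], lr=1e-3)       # bitsandbytes counterpart

g = torch.randn_like(p_ref) * 0.01
p_ref.grad, p_bnb.grad = g.clone(), g.clone()
torch_optim.step()
bnb_optim.step()
# Parameters and optimizer state are then checked with
# assert_most_approx_close(p_ref, p_bnb, ...).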
...@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): ...@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
for name1, name2 in str2statenames[optim_name]: for name1, name2 in str2statenames[optim_name]:
# since Lion can have pretty noisy updates where things lie at the boundary # since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion # allow up to 10 errors for Lion
assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], assert_most_approx_close(
atol=atol, rtol=rtol, torch_optimizer.state[p1][name1],
max_error_count=10) bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
max_error_count=10,
)
if gtype != torch.float32: if gtype != torch.float32:
# the adam buffers should also be close because they are 32-bit # the adam buffers should also be close because they are 32-bit
...@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype): ...@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8 eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize() bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config( bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
p3, "optim_bits", 8
)
bnb.optim.GlobalOptimManager.get_instance().register_parameters( bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
[p1, p2, p3]
)
p1 = p1.cuda() p1 = p1.cuda()
p2 = p2.cuda() p2 = p2.cuda()
p3 = p3.cuda() p3 = p3.cuda()
...@@ -242,7 +259,8 @@ optimizer_names_8bit = [ ...@@ -242,7 +259,8 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name): def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
absmax=bnb_optimizer.state[p2][max_val], absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2], A=bnb_optimizer.state[p2][name2],
) )
num_not_close = ( num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
torch.isclose( # assert num_not_close.sum().item() < 20
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
#assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone()) dequant_states.append(s1.clone())
err = torch.abs(p1 - p2) err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1)+1e-9) relerr = err / (torch.abs(p1) + 1e-9)
if g.dtype == torch.bfloat16: if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015 assert err.mean() < 0.00015
assert relerr.mean() < 0.0016 assert relerr.mean() < 0.0016
...@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item()) relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0: if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip( for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
str2statenames[optim_name], dequant_states
):
s1cpy = s.clone() s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone() raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone()
...@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
) )
torch.testing.assert_close(s1cpy, s1) torch.testing.assert_close(s1cpy, s1)
num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
assert num_not_close.sum().item() < 20 assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary # since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion # allow up to 5 errors for Lion
...@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): ...@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
for i in range(50): for i in range(50):
step += 1 step += 1
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
0.01 * i
)
g2 = g1.clone() g2 = g1.clone()
p2.grad = g2 p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
g1, gnorm_vec, step, 5
)
g1 = (g1.float() * gnorm_scale).to(gtype) g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1 p1.grad = g1
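The hunk above is the only place in this diff that exercises percentile clipping directly. Roughly, F.percentile_clipping keeps a running history of recent gradient norms in gnorm_vec and returns a scale that caps the current norm at the chosen percentile of that history; the caller then rescales the gradient before the optimizer step. A standalone sketch of the same calls — only the call shape is taken from the diff, and the 100-slot history buffer is an assumption:

import torch
import bitsandbytes.functional as F

dim1, dim2, gtype = 1024, 32, torch.float16
gnorm_vec = torch.zeros(100, device="cuda")  # assumed running gradient-norm history

for step in range(1, 51):
    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.01 * step
    # Returns the current gradient norm, the clipping threshold derived from
    # the history, and a scale factor to apply to the gradient.
    current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g, gnorm_vec, step, 5)
    g = (g.float() * gnorm_scale).to(gtype)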
...@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): ...@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) @pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name")) @pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode")) @pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark @pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
...@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): ...@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1 = layers1.cuda() layers1 = layers1.cuda()
large_tensor = None large_tensor = None
if mode == 'torch': if mode == "torch":
optim = str2optimizers[optim_name][0](layers1.parameters()) optim = str2optimizers[optim_name][0](layers1.parameters())
else: else:
optim = str2optimizers[optim_name][1](layers1.parameters()) optim = str2optimizers[optim_name][1](layers1.parameters())
# 12 GB # 12 GB
large_tensor = torch.empty((int(4.5e9),), device='cuda') large_tensor = torch.empty((int(4.5e9),), device="cuda")
torch.cuda.synchronize() torch.cuda.synchronize()
time.sleep(5) time.sleep(5)
num_batches = 5 num_batches = 5
batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
for i in range(num_batches): for i in range(num_batches):
print(i) print(i)
b = batches[i] b = batches[i]
if i ==2: if i == 2:
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
......
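The benchmark hunk ends mid-loop, but the timing pattern it uses is the usual one for CUDA work: run a couple of warm-up iterations, synchronize, start the wall clock, and synchronize again before reading it. A generic sketch of that pattern (the loop body is a placeholder, not the benchmark's actual forward/backward/optimizer step):

import time
import torch

batches = torch.randn(5, 128, 1024, device="cuda")
t0 = None
for i in range(batches.shape[0]):
    if i == 2:
        # Skip warm-up iterations (kernel launches, allocator growth)
        # before starting the timer.
        torch.cuda.synchronize()
        t0 = time.time()
    out = batches[i] @ batches[i].t()  # placeholder for the real step
torch.cuda.synchronize()  # drain queued kernels before reading the clock
print(f"timed {batches.shape[0] - 2} steps in {time.time() - t0:.3f}s")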