Unverified commit 06029dd6 authored by Titus, committed by GitHub

Merge pull request #1081 from akx/ruff-format

Reformat Python code with Ruff
parents fd723b78 5a4263f4
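The changes below are mechanical: Ruff's formatter rewrites single-quoted strings as double quotes, pads bare float literals, splits long argument lists one per line with trailing commas, and normalizes spacing around operators. A minimal illustration of the pattern (hypothetical snippet, mirroring the hunks below):

# before formatting
divfactor = 1. / (127. * 127.)
config = {'BLOCK_M': 128, 'SPLIT_K': 1}

# after formatting
divfactor = 1.0 / (127.0 * 127.0)
config = {"BLOCK_M": 128, "SPLIT_K": 1}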
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -15,21 +16,21 @@ else:
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=["n_elements"],
)
@triton.jit
def _dequantize_rowwise(
......@@ -51,7 +52,6 @@ else:
output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
......@@ -60,5 +60,5 @@ else:
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
_dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output
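For reference, a rough PyTorch equivalent of what this Triton kernel computes (a sketch, assuming x is a row-wise-quantized int8 matrix and state_x holds the per-row absmax values used when quantizing):

import torch

def dequantize_rowwise_ref(x: torch.Tensor, state_x: torch.Tensor) -> torch.Tensor:
    # out[i, j] = state_x[i] * x[i, j] / 127, matching `max_val * x * inv_127` in the kernel
    return (x.to(torch.float32) / 127.0 * state_x.unsqueeze(1)).to(torch.float16)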
......@@ -3,14 +3,14 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# Its purpose is fused matmul then dequantize
......@@ -27,58 +27,83 @@ else:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
num_stages=num_stages,
num_warps=num_warps,
),
)
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
num_stages=num_stages,
num_warps=num_warps,
pre_hook=init_to_zero("C"),
),
)
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
*get_configs_io_bound(),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
key=["M", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
def _int8_matmul_mixed_dequantize(
A,
B,
C,
bias,
state_x_ptr,
state_w_ptr,
M,
N,
K,
divfactor: tl.constexpr,
has_bias: tl.constexpr,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
......@@ -115,13 +140,13 @@ else:
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = w_factor * (x_factor * (acc * divfactor))
acc = acc.to(C.dtype.element_ty)
# conditionally add bias
......@@ -137,10 +162,9 @@ else:
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
divfactor = 1.0 / (127.0 * 127.0)
has_bias = 0 if bias is None else 1
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
......@@ -154,12 +178,28 @@ else:
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_mixed_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
_int8_matmul_mixed_dequantize[grid](
a,
b,
c,
bias,
state_x,
state_w,
M,
N,
K,
divfactor,
has_bias,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
GROUP_M=8,
ACC_TYPE=ACC_TYPE,
)
return c
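A rough PyTorch sketch of the result this fused kernel produces (assuming a is row-wise int8 with per-row absmax scales state_x, b is globally quantized int8 with a scalar absmax state_w, and bias is an optional float16 vector):

import torch

def int8_matmul_mixed_dequantize_ref(a, b, state_x, state_w, bias):
    # accumulate in float, then undo both quantization scales:
    # out ~= outer(state_x, state_w) * (a @ b) / (127 * 127), plus optional bias
    divfactor = 1.0 / (127.0 * 127.0)
    acc = a.float() @ b.float()
    out = state_w * (state_x.unsqueeze(1) * (acc * divfactor))
    if bias is not None:
        out = out + bias
    return out.to(torch.float16)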
......@@ -3,7 +3,9 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
return None
else:
import triton
import triton.language as tl
......@@ -17,7 +19,6 @@ else:
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
......@@ -26,58 +27,83 @@ else:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
num_stages=num_stages,
num_warps=num_warps,
),
)
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
num_stages=num_stages,
num_warps=num_warps,
pre_hook=init_to_zero("C"),
),
)
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
*get_configs_io_bound(),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
key=["M", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
def _int8_matmul_rowwise_dequantize(
A,
B,
C,
bias,
state_x_ptr,
state_w_ptr,
M,
N,
K,
divfactor,
has_bias: tl.constexpr,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
......@@ -114,13 +140,13 @@ else:
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = w_factor * (x_factor * (acc * divfactor))
acc = acc.to(C.dtype.element_ty)
if has_bias:
......@@ -135,9 +161,8 @@ else:
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
divfactor = 1. / (127. * 127.)
divfactor = 1.0 / (127.0 * 127.0)
has_bias = 0 if bias is None else 1
......@@ -154,12 +179,28 @@ else:
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_rowwise_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
_int8_matmul_rowwise_dequantize[grid](
a,
b,
c,
bias,
state_x,
state_w,
M,
N,
K,
divfactor,
has_bias,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
GROUP_M=8,
ACC_TYPE=ACC_TYPE,
)
return c
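Similarly, a rough sketch of the row-wise variant (an assumption on layout: state_x holds per-row absmax values of a, state_w per-column absmax values of b, i.e. per-row of the original weight):

import torch

def int8_matmul_rowwise_dequantize_ref(a, b, state_x, state_w, bias):
    # out ~= outer(state_x, state_w) * (a @ b) / (127 * 127), then optional bias
    divfactor = 1.0 / (127.0 * 127.0)
    acc = a.float() @ b.float()
    out = state_w.unsqueeze(0) * (state_x.unsqueeze(1) * (acc * divfactor))
    if bias is not None:
        out = out + bias
    return out.to(torch.float16)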
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
else:
def quantize_columnwise_and_transpose(x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -15,23 +16,23 @@ else:
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_stages=16),
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=16, num_warps=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
configs=[
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_stages=16),
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=16, num_warps=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=["n_elements"],
)
@triton.jit
def _quantize_columnwise_and_transpose(
......@@ -39,7 +40,8 @@ else:
output_ptr,
output_maxs,
n_elements,
M : tl.constexpr, N : tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
......@@ -47,12 +49,12 @@ else:
block_start = pid
p2_arange = tl.arange(0, P2)
p2_arange_mask = p2_arange < M
arange = p2_arange * N
arange = p2_arange * N
offsets = block_start + arange
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
output = tl.libdevice.llrint(127.0 * (x / max_val))
new_start = pid * M
new_offsets = new_start + p2_arange
......@@ -68,6 +70,6 @@ else:
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
return output, output_maxs
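For reference, a rough PyTorch equivalent of the column-wise quantize-and-transpose (a sketch; the kernel additionally pads the per-column reduction up to the next power of two, P2):

import torch

def quantize_columnwise_and_transpose_ref(x: torch.Tensor):
    # per-column absmax, symmetric int8 quantization, result written transposed
    absmax = x.abs().amax(dim=0)
    out = torch.round(127.0 * (x / absmax)).to(torch.int8).t().contiguous()
    return out, absmax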
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_global_transpose(input): return None
def quantize_global(x: torch.Tensor): return None
else:
def quantize_global_transpose(input):
return None
def quantize_global(x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
# global quantize
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
],
key=['n_elements']
configs=[
triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
],
key=["n_elements"],
)
@triton.jit
def _quantize_global(
......@@ -34,35 +35,43 @@ else:
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
output = tl.libdevice.llrint(127. * (x * absmax_inv))
output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
tl.store(output_ptr + offsets, output, mask=mask)
def quantize_global(x: torch.Tensor):
absmax = x.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
absmax_inv = 1.0 / absmax
output = torch.empty(*x.shape, device="cuda", dtype=torch.int8)
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_quantize_global[grid](x, absmax_inv, output, n_elements)
return output, absmax
# global quantize and transpose
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
# ...
],
key=['M', 'N']
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
# ...
],
key=["M", "N"],
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
BLOCK_M : tl.constexpr,
BLOCK_N : tl.constexpr,
GROUP_M : tl.constexpr):
def _quantize_global_transpose(
A,
absmax_inv_ptr,
B,
stride_am,
stride_an,
stride_bn,
stride_bm,
M,
N,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
......@@ -86,20 +95,30 @@ else:
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
output = tl.libdevice.llrint(127. * (a * absmax_inv))
output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
tl.store(B, output, mask=mask)
def quantize_global_transpose(input):
absmax = input.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax
absmax_inv = 1.0 / absmax
M, N = input.shape
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
out = torch.empty(N, M, device="cuda", dtype=torch.int8)
assert out.size(0) == N and out.size(1) == M
assert input.stride(0) == 1 or input.stride(1) == 1
assert out.stride(0) == 1 or out.stride(1) == 1
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
_quantize_global_transpose[grid](
input,
absmax_inv,
out,
input.stride(0),
input.stride(1),
out.stride(0),
out.stride(1),
M,
N,
)
return out, absmax
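A rough PyTorch equivalent of the two global-quantization helpers above (a sketch of the elementwise math the Triton kernels perform blockwise):

import torch

def quantize_global_ref(x: torch.Tensor):
    # one absmax for the whole tensor, symmetric int8 quantization
    absmax = x.abs().max().unsqueeze(0)
    out = torch.round(127.0 * (x / absmax)).to(torch.int8)
    return out, absmax

def quantize_global_transpose_ref(x: torch.Tensor):
    # same scale, but the int8 output is written transposed
    out, absmax = quantize_global_ref(x)
    return out.t().contiguous(), absmax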
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_rowwise(x: torch.Tensor): return None
else:
def quantize_rowwise(x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -15,21 +16,21 @@ else:
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=["n_elements"],
)
@triton.jit
def _quantize_rowwise(
......@@ -49,7 +50,7 @@ else:
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
output = tl.libdevice.llrint(127.0 * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
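And the row-wise counterpart, as a rough PyTorch sketch (per-row absmax, symmetric int8, with the per-row scales returned alongside the quantized tensor):

import torch

def quantize_rowwise_ref(x: torch.Tensor):
    absmax = x.abs().amax(dim=1)
    out = torch.round(127.0 * (x / absmax.unsqueeze(1))).to(torch.int8)
    return out, absmax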
......@@ -30,7 +30,7 @@ def outlier_hook(module, input):
# (1) zscore test of std of hidden dimension
outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
# (2) magnitude > 6 test
dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1)))
dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1)))
outlier_idx2 = torch.where(dims > 0)[0]
outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
tracer.hvalue2outlier_idx[hvalue] = outlier_idx
......@@ -59,14 +59,14 @@ class OutlierTracer:
self.hooks.append(m.register_forward_pre_hook(outlier_hook))
def is_initialized(self):
return getattr(self, 'initialized', False)
return getattr(self, "initialized", False)
def get_hvalue(self, weight):
return weight.data.storage().data_ptr()
def get_outliers(self, weight):
if not self.is_initialized():
print('Outlier tracer is not initialized...')
print("Outlier tracer is not initialized...")
return None
hvalue = self.get_hvalue(weight)
if hvalue in self.hvalue2outlier_idx:
......@@ -80,6 +80,7 @@ class OutlierTracer:
cls._instance = cls.__new__(cls)
return cls._instance
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
if rdm:
return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
......@@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
m = weight.mean(reduction_dim)
mm = m.mean()
mstd = m.std()
zm = (m-mm)/mstd
zm = (m - mm) / mstd
std = weight.std(reduction_dim)
stdm = std.mean()
stdstd = std.std()
zstd = (std-stdm)/stdstd
zstd = (std - stdm) / stdstd
if topk is not None:
val, idx = torch.topk(std.abs(), k=topk, dim=0)
......@@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
return tuple(
to_decode.decode("UTF-8").strip()
for to_decode in subprocess_err_out_tuple
)
return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple)
def execute_and_return_decoded_std_streams(command_string):
return _decode(
......@@ -116,14 +114,13 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
shlex.split(command_string),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
).communicate()
).communicate(),
)
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err
def replace_linear(
model,
linear_replacement,
......@@ -163,8 +160,9 @@ def replace_linear(
model._modules[name].bias = old_module.bias
if post_processing_function is not None:
func = getattr(module, post_processing_function, None)
if func is not None: func(module)
func = getattr(module, post_processing_function, None)
if func is not None:
func(module)
return model
......@@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict):
A torch tensor containing the packed data.
"""
json_str = json.dumps(source_dict)
json_bytes = json_str.encode('utf-8')
json_bytes = json_str.encode("utf-8")
tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8)
return tensor_data
......@@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data):
A Python dictionary containing the unpacked data.
"""
json_bytes = bytes(tensor_data.cpu().numpy())
json_str = json_bytes.decode('utf-8')
json_str = json_bytes.decode("utf-8")
unpacked_dict = json.loads(json_str)
return unpacked_dict
......@@ -2,14 +2,14 @@ import torch
import bitsandbytes as bnb
p = torch.nn.Parameter(torch.rand(10,10).cuda())
a = torch.rand(10,10).cuda()
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()
p1 = p.data.sum().item()
adam = bnb.optim.Adam([p])
out = a*p
out = a * p
loss = out.sum()
loss.backward()
adam.step()
......@@ -17,5 +17,5 @@ adam.step()
p2 = p.data.sum().item()
assert p1 != p2
print('SUCCESS!')
print('Installation was successful!')
print("SUCCESS!")
print("Installation was successful!")
......@@ -2,23 +2,18 @@ import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
MAX_NEW_TOKENS = 128
model_name = 'meta-llama/Llama-2-7b-hf'
model_name = "meta-llama/Llama-2-7b-hf"
text = 'Hamburg is in which country?\n'
text = "Hamburg is in which country?\n"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
model = LlamaForCausalLM.from_pretrained(
model_name,
device_map='auto',
load_in_8bit=True,
max_memory=max_memory
)
model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory)
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
......@@ -19,6 +19,7 @@ cuda_versions = {
"123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
}
def install_cuda(version, base_path, download_path):
formatted_version = f"{version[:-1]}.{version[-1]}"
folder = f"cuda-{formatted_version}"
......@@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path):
subprocess.run(["rm", "-rf", install_path], check=True)
url = cuda_versions[version]
filename = url.split('/')[-1]
filename = url.split("/")[-1]
filepath = os.path.join(download_path, filename)
if not os.path.exists(filepath):
......@@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path):
# Install CUDA
print(f"Installing CUDA version {version}...")
install_command = [
"bash", filepath,
"--no-drm", "--no-man-page", "--override",
"--toolkitpath=" + install_path, "--toolkit", "--silent"
"bash",
filepath,
"--no-drm",
"--no-man-page",
"--override",
"--toolkitpath=" + install_path,
"--toolkit",
"--silent",
]
print(f"Running command: {' '.join(install_command)}")
......@@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path):
print(f"CUDA version {version} installed at {install_path}")
def main():
user_base_path = os.path.expanduser("~/cuda")
system_base_path = "/usr/local/cuda"
......@@ -93,5 +100,6 @@ def main():
print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
sys.exit(1)
if __name__ == "__main__":
main()
......@@ -8,6 +8,10 @@ src = [
"tests",
"benchmarking"
]
target-version = "py38"
line-length = 119
[tool.ruff.lint]
select = [
"B", # bugbear: security warnings
"E", # pycodestyle
......@@ -17,7 +21,6 @@ select = [
"UP", # alert you when better syntax is available in your python version
"RUF", # the ruff developer's own rules
]
target-version = "py38"
ignore = [
"B007", # Loop control variable not used within the loop body (TODO: enable)
"B028", # Warning without stacklevel (TODO: enable)
......@@ -30,7 +33,7 @@ ignore = [
]
ignore-init-module-imports = true # allow to expose in __init__.py via imports
[tool.ruff.extend-per-file-ignores]
[tool.ruff.lint.extend-per-file-ignores]
"**/__init__.py" = ["F401"] # allow unused imports in __init__.py
"{benchmarking,tests}/**/*.py" = [
"B007",
......@@ -42,7 +45,7 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports
"UP030",
]
[tool.ruff.isort]
[tool.ruff.lint.isort]
combine-as-imports = true
detect-same-package = true
force-sort-within-sections = true
......@@ -15,6 +15,7 @@
Script to close stale issue. Taken in part from the AllenNLP repository.
https://github.com/allenai/allennlp.
"""
from datetime import datetime as dt, timezone
import os
......@@ -50,7 +51,7 @@ def main():
issue.create_comment(
"This issue has been automatically marked as stale because it has not had "
"recent activity. If you think this still needs to be addressed "
"please comment on this thread.\n\n"
"please comment on this thread.\n\n",
)
......
......@@ -20,7 +20,11 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"])
@pytest.mark.parametrize(
"funcs",
[(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
......@@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(25):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
......@@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
......@@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B.grad = None
if req_grad[0]:
torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
......@@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_close(
out_bnb, out_torch, atol=0.027, rtol=0.2
)
torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad):
out_bnb.data.copy_(out_torch)
......@@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
......@@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B.grad = None
if req_grad[0]:
torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
......@@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
......@@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
B.grad = None
if req_grad[0]:
torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
......@@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"])
@pytest.mark.parametrize(
"funcs",
[(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)],
ids=["func=matmul", "func=switchback_bnb"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
decomp,
has_fp16_weights,
has_bias
):
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
......@@ -245,18 +222,13 @@ def test_matmullt(
req_grad[2] = False
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(
size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
)
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
B = torch.randn(
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(
size=(dim2, dim4),
device="cuda",
......@@ -266,7 +238,7 @@ def test_matmullt(
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
......@@ -311,9 +283,7 @@ def test_matmullt(
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(
out_bnb, target
).mean()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
......@@ -323,9 +293,7 @@ def test_matmullt(
gradBias1 = bias.grad
bias.grad = None
loss_torch = torch.nn.functional.mse_loss(
out_torch, target
).mean()
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
......@@ -336,9 +304,7 @@ def test_matmullt(
bias.grad = None
if req_grad[0]:
torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
......@@ -352,9 +318,7 @@ def test_matmullt(
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2)
......@@ -370,8 +334,20 @@ def test_matmullt(
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type"))
def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
def test_matmul_4bit(
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
has_bias,
compress_statistics,
quant_type,
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False:
......@@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
B2, quant_state = bnb.functional.quantize_4bit(
B,
compress_statistics=compress_statistics,
quant_type=quant_type,
)
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
......@@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
if n > 0:
assert err < 0.115
#assert err < 0.20
# assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
......@@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
gradBias1 = bias.grad
bias.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
......@@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias.grad = None
if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2)
......@@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global'])
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
@pytest.mark.parametrize(
"funcs",
[(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
ids=["matmul_fp8_mixed", "matmul_fp8_global"],
)
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad)
......@@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0:
assert err < 0.115
#assert err < 0.20
# assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
......@@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
......@@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
......@@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
grad_err = (gradB1-gradB2).abs().mean()
grad_err = (gradB1 - gradB2).abs().mean()
assert grad_err.item() < 0.003
torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
......@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert (
get_cuda_bnb_library_path(cuda111_noblas_spec).stem
== "libbitsandbytes_cuda111_nocublaslt"
)
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
......@@ -19,9 +19,7 @@ from tests.helpers import (
id_formatter,
)
torch.set_printoptions(
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
k = 20
......@@ -98,9 +96,7 @@ def teardown():
pass
@pytest.mark.parametrize(
"dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
......@@ -136,7 +132,6 @@ def test_quantile_quantization():
assert diff < 0.001
def test_dynamic_quantization():
diffs = []
reldiffs = []
......@@ -149,8 +144,8 @@ def test_dynamic_quantization():
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
print(sum(diffs)/len(diffs))
print(sum(reldiffs)/len(reldiffs))
print(sum(diffs) / len(diffs))
print(sum(reldiffs) / len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
......@@ -161,13 +156,12 @@ def test_dynamic_quantization():
assert diff < 0.004
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
#print('')
# print('')
diffs = []
reldiffs = []
for i in range(100):
......@@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
# print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
# print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert abserr < 0.011
assert relerr < 0.018
assert A2.dtype == dtype
......@@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
reldiff = diff / torch.abs(A1.float() + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
# torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
if signed:
assert abserr < 0.0035
assert relerr < 0.015
......@@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
assert abserr < 0.00175
assert relerr < 0.012
assert A2.dtype == dtype
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
@pytest.mark.parametrize(
"gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
......@@ -223,9 +214,7 @@ def test_percentile_clipping(gtype):
for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping(
g, gnorm_vec2, step, percentile=percentile
)
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
......@@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
#print("")
# print("")
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
......@@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_close(
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
)
torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
......@@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
#print(mean(errors))
#print(mean(relerrors))
# print(mean(errors))
# print(mean(relerrors))
def test_stable_embedding():
......@@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (
(batch_dim, hidden_dim)
if not transpose[0]
else (hidden_dim, batch_dim)
)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
......@@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
)
shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
......@@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
for i in range(25):
A = torch.randint(
-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
).to(torch.int8)
B = torch.randint(
-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
).to(torch.int8)
A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
iout = torch.empty(
A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
)
iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
out = F.igemm(A, B, out=iout)
torch.testing.assert_close(out.float(), out2)
......@@ -444,9 +413,7 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = []
relerrs2 = []
for i in range(k):
A = torch.normal(
0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
)
A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
......@@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
out2 = torch.bmm(
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
)
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_close(out.float(), out2.float())
......@@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3):
qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
......@@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
dtype
)
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
out, S = F.nvidia_transform(A, to_order=orderOut)
......@@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
n = (
A.shape[0]
* A.shape[1]
* (A.shape[2] + (32 - (A.shape[2] % 32)))
)
n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
assert out.numel() == n
elif orderOut == "col_turing":
# 32 col 8 row tiles
n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
A.shape[1] + (32 - (A.shape[1] % 32))
)
n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]):
......@@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
rowtile = (
(row // 8) + (1 if row % 8 != 0 else 0)
) * total_coltile
rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
row2 = (row % 8) * 32
......@@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32":
out2, S = F.nvidia_transform(
out, from_order=orderOut, to_order="row", state=S
)
out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
torch.testing.assert_close(A, out2)
......@@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
torch.int8
)
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
elif dims == 3:
A = torch.randint(
-128, 127, size=(dim1, dim2, dim3), device="cuda"
).to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
torch.int8
)
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32")
......@@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.testing.assert_close(C1, C3.float())
# transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
torch.int8
)
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
C1 = torch.matmul(A.float(), B.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True)
......@@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
A = torch.normal(
0, 0.5, size=(dim1, dim2, dim3), device="cuda"
).half()
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
......@@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_close(C1, C3.float())
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
......@@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
......@@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden):
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
inner = torch.randint(1, 128, size=(1,)).item()
bias = None
if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
formatB = F.get_special_format_str()
for i in range(1):
A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
if has_bias: C1 += bias
if has_bias:
C1 += bias
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
......@@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
if has_bias: C4 += bias
if has_bias:
C4 += bias
# TODO: is something wrong here? If so, the problem goes deeper
# n = C1.numel()
......@@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims):
else:
assert False
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=threshold
)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
A_blocked = einops.rearrange(
torch.abs(A),
......@@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims):
torch.testing.assert_close(row_stats1_trunc, row_stats2)
torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=0.0
)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_close(row_stats1, row_stats2)
......@@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (
(torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
)
num_not_close_cols = (
(torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
)
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
print(
f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
)
print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
assert False
if num_not_close_rows > (min_error * n):
print(
f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
)
print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
assert False
torch.testing.assert_close(Srow.flatten().float(), statsA)
......@@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2):
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner)
in zip(
for (dim1, dim4, inner) in zip(
get_test_dims(1, 4 * 1024, n=4),
get_test_dims(1, 4 * 1024, n=4),
get_test_dims(1, 4 * 1024, n=4),
)
)
),
)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
......@@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner):
("dim1", "dim4", "inner"),
(
pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
for (dim1, dim4, inner)
in zip(
for (dim1, dim4, inner) in zip(
get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6),
get_test_dims(1, 4 * 1024, n=6),
)
)
),
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
......@@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
......@@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner):
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
torch.cuda.synchronize()
print("row-wise", time.time() - t0)
......@@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner):
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
dtype
)
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(
10, 99, size=(dim1, dim2, dim3), device="cuda"
).to(dtype)
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
A.view(-1)[-1] = -1
if transpose:
......@@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2):
idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
if coo_tensor is not None:
A1 = A * idx
A2 = torch.zeros_like(A)
A2[
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
torch.testing.assert_close(A1, A2)
A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
)
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
......@@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
if transposed_B:
......@@ -1303,9 +1219,7 @@ def test_spmm_bench():
print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
for i in range(10):
out2 = F.spmm_coo(cooA, B)
......@@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2):
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
......@@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
......@@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
std = out1.std()
out1 /= std
out2 /= std
assert_all_approx_close(
out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
)
assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
......@@ -1443,9 +1351,7 @@ def test_coo2csr():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
......@@ -1463,9 +1369,7 @@ def test_coo2csc():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
......@@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
......@@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")],
[pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
......@@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
#linearMixedBit.eval()
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
# linearMixedBit.eval()
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
......@@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
print(
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
#torch.cuda.synchronize()
#print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
# torch.cuda.synchronize()
# print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
#torch.cuda.synchronize()
#print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
# torch.cuda.synchronize()
# print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
torch.cuda.synchronize()
print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul(A, B)
#torch.cuda.synchronize()
#print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# torch.cuda.synchronize()
# print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul(A, B, threshold=6.0)
#torch.cuda.synchronize()
#print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
#C32A, SA = F.transform(CA, "col32")
#CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
#CxB, SB = F.transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
# C32A, SA = F.transform(CA, "col32")
# CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
#torch.cuda.synchronize()
#print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1)
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1)
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
#torch.cuda.synchronize()
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
#torch.cuda.synchronize()
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# torch.cuda.synchronize()
# print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linear8bit(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linearMixedBit(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linearMixedBit(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linearMixedBit(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train_thresh(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linear8bit_train_thresh(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
def test_zeropoint():
def quant_zp(x):
......@@ -1778,8 +1682,8 @@ def test_zeropoint():
C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
#print(ca.min(), ca.max())
#print((ca - cza).min(), (ca - cza).max())
# print(ca.min(), ca.max())
# print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
......@@ -1808,14 +1712,14 @@ def test_zeropoint():
C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb
#print("")
# print("")
# print(C0.flatten()[:10])
#print(C1.flatten()[:10])
#print(C2.flatten()[:10])
#print(C3.flatten()[:10])
#print(C5.flatten()[:10])
#print(C6.flatten()[:10])
#print(C7.flatten()[:10])
# print(C1.flatten()[:10])
# print(C2.flatten()[:10])
# print(C3.flatten()[:10])
# print(C5.flatten()[:10])
# print(C6.flatten()[:10])
# print(C7.flatten()[:10])
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
......@@ -1852,16 +1756,15 @@ def test_extract_outliers():
torch.testing.assert_close(outliers1, outliers2)
def test_blockwise_cpu_large():
diffs = []
reldiffs = []
batch = 128
seq = 128
for hidden in [128]:#, 14336]:
for hidden in [128]: # , 14336]:
for blocksize in [4096, 16384]:
for i in range(2):
A1 = torch.randn(batch, seq, hidden, device='cpu')
A1 = torch.randn(batch, seq, hidden, device="cpu")
t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
......@@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large():
# print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7-e_bits
p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
abserr = []
......@@ -1888,12 +1790,12 @@ def test_fp8_quant():
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
......@@ -1902,12 +1804,12 @@ def test_fp8_quant():
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
# assert diff < 0.0075
# print(sum(abserr)/len(abserr))
# print(sum(relerr)/len(relerr))
abserr = []
relerr = []
......@@ -1916,50 +1818,48 @@ def test_fp8_quant():
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
reldiff = diff / torch.abs(A1 + 1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr))
# assert diff < 0.0075
# print(3, sum(abserr)/len(abserr))
# print(3, sum(relerr)/len(relerr))
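For orientation, a minimal sketch of the blockwise round trip that test_fp8_quant exercises, assuming a CUDA build of bitsandbytes; the 3/4 exponent/precision split below is just one configuration from the loop above, not the only one tested.

import torch
import bitsandbytes.functional as F

# build a signed FP8-style code map (3 exponent bits, 4 precision bits, 8 bits total)
code = F.create_fp8_map(True, 3, 4).cuda()
A1 = torch.randn(1024, 1024, device="cuda")
# quantize blockwise against the custom code, then reconstruct
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
# mean absolute and relative reconstruction error, as tracked by the test
diff = torch.abs(A1 - A2)
print(diff.mean().item(), (diff / torch.abs(A1 + 1e-8)).mean().item())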
def test_few_bit_quant():
#print('')
# print('')
for bits in range(2, 9):
#print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
# print('='*30, bits, '='*30)
for method in ["linear", "fp8", "dynamic", "quantile"]:
abserrs = []
relerrs = []
code = None
if method == 'linear':
if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8':
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
elif method == "fp8":
ebits = math.ceil(bits / 2)
pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == 'dynamic':
code = F.create_dynamic_map(True, bits-0, bits).cuda()
elif method == 'quantile':
values = torch.randn(2048, 2048, device='cuda')
elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).cuda()
elif method == "quantile":
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
#print(method, (code==0).sum())
assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
# print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device='cuda')
values = torch.randn(1, 32, device="cuda")
values /= values.abs().max()
#values[values.abs() < 1e-6] += 1e-5
# values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v-code).argmin()
idx = torch.abs(v - code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
......@@ -1970,62 +1870,61 @@ def test_few_bit_quant():
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2-values)
err2 = torch.abs(v2 - values)
abserrs.append(err2.mean().item())
relerrs.append((err2/(1e-10+values).abs()).mean().item())
relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
#assert err2.mean() <= err1
err1 = torch.abs(v1 - values).mean()
# assert err2.mean() <= err1
else:
torch.testing.assert_close(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
# print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
# assert False
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
err = torch.abs(val1 - val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
data = torch.randn(1024, 1024, device="cuda")
for bits in range(2, 4):
total_values = 2**bits-1
p = np.linspace(0, 1, 2*total_values+1)
idx = np.arange(1, 2*total_values+1, 2)
total_values = 2**bits - 1
p = np.linspace(0, 1, 2 * total_values + 1)
idx = np.arange(1, 2 * total_values + 1, 2)
p = p[idx]
offset = 1/(2*total_values)
p = np.linspace(offset, 1-offset, total_values)
offset = 1 / (2 * total_values)
p = np.linspace(offset, 1 - offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
err = torch.abs(val1-val2).mean()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
err = torch.abs(val1 - val2).mean()
assert err < 0.035
@pytest.mark.benchmark
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
code =F.create_fp8_map(True, 3, 0, 4).cuda()
a = torch.rand(1024, 1024, device="cuda").half()
code = F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max())
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)
max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
# print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize()
#print((time.time()-t0)/1e6)
# print((time.time()-t0)/1e6)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
......@@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype):
result = 0
bias = 3
sign, e1, e2, p1 = bits
idx = sign*8 + e1*4 + e2*2 + p1*1
idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
sign = -1.0 if sign else 1.0
exp = e1*2 + e2*1
exp = e1 * 2 + e2 * 1
if exp == 0:
# sub-normal
if p1 == 0: result = 0
else: result = sign*0.0625
if p1 == 0:
result = 0
else:
result = sign * 0.0625
else:
# normal
exp = 2**(-exp + bias + 1)
exp = 2 ** (-exp + bias + 1)
frac = 1.5 if p1 else 1.0
result = sign*exp*frac
result = sign * exp * frac
code[idx] = result
A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
qa, SA = F.quantize_fp4(A1, blocksize=64)
A2 = F.dequantize_fp4(qa, SA)
err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-8)).mean()
relerr = (err / (A1.abs().float() + 1e-8)).mean()
idx = err > 1.0
err = err.mean()
......@@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype):
assert relerr.item() < 0.28
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
for i in range(10):
A1 = torch.randn(1024, 1024, device='cuda').half()
A1 = torch.randn(1024, 1024, device="cuda").half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean()
relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean()
errs1.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
err = (A1 - A3).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean()
relerr = (err / (A1.abs().float() + 1e-15)).mean()
err = err.mean()
errs2.append(err.item())
......@@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type):
assert err.item() < 0.11
assert relerr.item() < 0.28
#print(sum(errs1)/len(errs1), blocksize, quant_type)
#print(sum(errs2)/len(errs2), blocksize, quant_type)
# print(sum(errs1)/len(errs1), blocksize, quant_type)
# print(sum(errs2)/len(errs2), blocksize, quant_type)
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
blocksize = 256
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel()/2
output_size = a.numel()*2
num_bytes = input_size+output_size
GB = num_bytes/1e9
max_theoretical_s = GB/768
#print(max_theoretical_s*1e6)
b = torch.randn(128, 1024*12, device='cuda').half()
input_size = a.numel() / 2
output_size = a.numel() * 2
num_bytes = input_size + output_size
GB = num_bytes / 1e9
max_theoretical_s = GB / 768
# print(max_theoretical_s*1e6)
b = torch.randn(128, 1024 * 12, device="cuda").half()
iters = 100
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
#b.copy_(a)
# b.copy_(a)
torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
# print((time.time()-t0)/iters*1e6)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# torch.matmul(b, a.t())
#torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
# torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
def test_normal_map_tree():
code = F.create_normal_map()
values =code[:8].tolist() + code[-8:].tolist()
values = code[:8].tolist() + code[-8:].tolist()
num_pivots = 1
#print(values)
while num_pivots <16:
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
#print(idx)
# print(values)
while num_pivots < 16:
idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
# print(idx)
num_pivots *= 2
pivots = []
for i in idx:
pivots.append((values[i-1]+values[i])/2)
#print(pivots)
pivots.append((values[i - 1] + values[i]) / 2)
# print(pivots)
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
"quant_storage",
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]:
#for dim in [4*1024]:
#for dim in [1*16]:
# for dim in [4*1024]:
# for dim in [1*16]:
errs1 = []
errs2 = []
errs3 = []
......@@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
max_errs2 = []
max_errs3 = []
for i in range(100):
if kind == 'fc1':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'fc2':
A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'attn':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'attn_packed':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "fc2":
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
qB, state = F.quantize_4bit(
B,
quant_type=storage_type,
compress_statistics=double_quant,
quant_storage=quant_storage,
)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
err1 = (C1-C2).abs().float()
err2 = (C3-C2).abs().float()
err3 = (C3-C1).abs().float()
err1 = (C1 - C2).abs().float()
err2 = (C3 - C2).abs().float()
err3 = (C3 - C1).abs().float()
mag1 = torch.abs(C1).float()+1e-5
mag2 = torch.abs(C3).float()+1e-5
mag3 = torch.abs(C3).float()+1e-5
mag1 = torch.abs(C1).float() + 1e-5
mag2 = torch.abs(C3).float() + 1e-5
mag3 = torch.abs(C3).float() + 1e-5
relerr1 = err1/mag1
relerr2 = err2/mag2
relerr3 = err3/mag3
relerr1 = err1 / mag1
relerr2 = err2 / mag2
relerr3 = err3 / mag3
max_err1 = err1.max()
max_err2 = err2.max()
......@@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
max_errs2.append(max_err2.item())
max_errs3.append(max_err3.item())
c = int(C1.numel()*0.0014*(dim/256))+1
c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
absratio = err2/err3
relratio = relerr2/relerr3
maxratio = relerr2/relerr3
err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
absratio = err2 / err3
relratio = relerr2 / relerr3
maxratio = relerr2 / relerr3
# for debugging if the test fails
#
#print('='*80)
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#print(C1.flatten()[-20:])
#print(C2.flatten()[-20:])
#print(f'inference vs training abs: {err1}')
#print(f'inference vs training rel: {relerr1}')
#print(f'inference vs training max: {maxerr1}')
#print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}')
# print('='*80)
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
# print(C1.flatten()[-20:])
# print(C2.flatten()[-20:])
# print(f'inference vs training abs: {err1}')
# print(f'inference vs training rel: {relerr1}')
# print(f'inference vs training max: {maxerr1}')
# print(f'inference vs training vs torch err ratio abs: {absratio}')
# print(f'inference vs training vs torch err ratio rel: {relratio}')
# print(f'inference vs training vs torch err ratio max: {maxratio}')
if dtype == torch.float16:
if dim <= 512:
assert err1 < 7e-5
......@@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32*10
n = 32 * 10
A = F.get_paged(n, n, dtype=torch.float32)
B = F.get_paged(n, n, dtype=torch.uint8)
B2 = F.get_paged(n, n, dtype=torch.float32)
assert A.is_paged
assert B.is_paged
assert A.page_deviceid==0
assert B.page_deviceid==0
assert A.page_deviceid == 0
assert B.page_deviceid == 0
F.fill(A, 17.0)
F.fill(B, 17)
F.fill(B2, 2)
assert (A==17).sum().item() == n*n
assert (B==17).sum().item() == n*n
C = A*B.float()
assert (C==289).sum().item() == n*n
assert (A == 17).sum().item() == n * n
assert (B == 17).sum().item() == n * n
C = A * B.float()
assert (C == 289).sum().item() == n * n
F._mul(A, B2)
F._mul(A, B2)
F._mul(A, B2)
assert (A==17*(2**3)).sum().item() == n*n
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
assert (A == 17 * (2**3)).sum().item() == n * n
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.fill(B2, 17.0)
# F._mul(A, B2)
# F.fill(B2, 17.0)
# F._mul(A, B2)
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize()
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize()
# assert (A==17).sum().item() == n*n
# assert (A==17).sum().item() == n*n
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64-(dim % 64)) for dim in dims]
#for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
dims = [dim + (64 - (dim % 64)) for dim in dims]
# for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for dim in dims:
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
B = torch.eye(dim, dtype=dtype, device='cuda')
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
B = torch.eye(dim, dtype=dtype, device="cuda")
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
......@@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
torch.testing.assert_close(A, C3)
torch.testing.assert_close(A, C1)
torch.testing.assert_close(A, C2)
#torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
#torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
# torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
# torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
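For reference, a compact sketch of the 4-bit quantize/matmul path that test_gemv_4bit and test_gemv_eye_4bit probe above, assuming a CUDA build of bitsandbytes; the shapes and the NF4 quant_type are illustrative picks, not the full parameter grid.

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 512, dtype=torch.float16, device="cuda")
B = torch.randn(2048, 512, dtype=torch.float16, device="cuda")
# quantize the weight to 4-bit NF4 storage together with its quantization state
qB, state = F.quantize_4bit(B, quant_type="nf4")
# fused 4-bit inference path vs. a dequantize-then-matmul reference
C_inf = bnb.matmul_4bit(A, qB.t(), state)
C_ref = torch.matmul(A, F.dequantize_4bit(qB, state, quant_type="nf4").t())
print((C_inf - C_ref).abs().max().item())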
......@@ -10,56 +10,61 @@ transformers = pytest.importorskip("transformers")
def get_4bit_config():
return transformers.BitsAndBytesConfig(
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4',
)
return transformers.BitsAndBytesConfig(
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
def get_model_and_tokenizer(config):
model_name_or_path, quant_type = config
bnb_config = get_4bit_config()
if quant_type == '16bit':
if quant_type == "16bit":
bnb_config.load_in_4bit = False
else:
bnb_config.bnb_4bit_quant_type= quant_type
model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path,
bnb_config.bnb_4bit_quant_type = quant_type
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=bnb_config,
max_memory={0:'48GB'},
device_map='auto',
torch_dtype=torch.bfloat16
).eval()
max_memory={0: "48GB"},
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
return model, tokenizer
def get_prompt_for_generation_eval(text, add_roles=True):
description = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
if add_roles:
prompt = f'{description} ### Human: {text} ### Assistant:'
prompt = f"{description} ### Human: {text} ### Assistant:"
else:
prompt = f'{description} {text}'
prompt = f"{description} {text}"
return prompt
def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval):
text = prompt_func(text)
inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
dtypes = ['nf4', 'fp4']
@pytest.fixture(scope='session', params=product(models, dtypes))
models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"]
dtypes = ["nf4", "fp4"]
@pytest.fixture(scope="session", params=product(models, dtypes))
def model_and_tokenizer(request):
model, tokenizer = get_model_and_tokenizer(request.param)
yield request.param, model, tokenizer
......@@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
)
generation_config.max_new_tokens = 20
#text = 'Please write down the first 50 digits of pi.'
#text = get_prompt_for_generation_eval(text)
#text += ' Sure, here the first 50 digits of pi: 3.14159'
# text = 'Please write down the first 50 digits of pi.'
# text = get_prompt_for_generation_eval(text)
# text += ' Sure, here the first 50 digits of pi: 3.14159'
n_cases = 6
text = '3.14159'
if hasattr(model.config, 'quantization_config'):
text = "3.14159"
if hasattr(model.config, "quantization_config"):
model.config.quantization_config.bnb_4bit_compute_dtype = dtype
model.config.quantization_config.bnb_4bit_use_double_quant = DQ
if not inference_kernel:
text = [text]*n_cases
inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
x = inputs['input_ids']
text = [text] * n_cases
inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
x = inputs["input_ids"]
outputs = []
if inference_kernel:
for i in range(n_cases):
......@@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
outputs = model.generate(x, generation_config=generation_config)
outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
assert len(outputs) == n_cases
failure_count = 0
for i in range(n_cases):
if not outputs[i][:len(str(math.pi))] == str(math.pi):
if not outputs[i][: len(str(math.pi))] == str(math.pi):
failure_count += 1
failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4)
failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4
if failure_count > failure_max:
print(math.pi)
for out in outputs:
print(out)
raise ValueError(f'Failure count: {failure_count}/{n_cases}')
raise ValueError(f"Failure count: {failure_count}/{n_cases}")
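For context, a minimal sketch of the 4-bit loading and generation flow these tests drive, assuming transformers, bitsandbytes, and a CUDA device are available; the model name is the first entry of the models list above and the prompt mirrors the pi prefix used in test_pi.

import torch
import transformers

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained("huggyllama/llama-7b")
inputs = tokenizer("3.14159", return_tensors="pt").to("cuda:0")
# greedy continuation of the pi prefix
outputs = model.generate(inputs=inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))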
......@@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
device = "cuda"
layer_shape = (300, 400)
linear = torch.nn.Linear(
*layer_shape, dtype=original_dtype, device="cpu"
) # original layer
linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
# Quantizing original layer
linear_q = bnb.nn.Linear4bit(
......@@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
quant_type=quant_type,
device="meta",
)
new_weight = bnb.nn.Params4bit(
data=linear.weight, quant_type=quant_type, requires_grad=False
)
new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
linear_q.weight = new_weight
if bias:
linear_q.bias = torch.nn.Parameter(linear.bias)
......@@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
target_compression = (
0.143 if original_dtype == torch.float32 else 0.29
) # these numbers get lower as weight shape increases
ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
ratio_error_msg = (
f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
)
assert size_ratio < target_compression, ratio_error_msg
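A minimal sketch of the quantize-on-transfer pattern the serialization test above builds on, assuming a CUDA device; the 300x400 shape follows the test, and fp32 is just one of the dtypes it covers.

import torch
import bitsandbytes as bnb

linear = torch.nn.Linear(300, 400, dtype=torch.float32, device="cpu")  # original layer
linear_q = bnb.nn.Linear4bit(
    linear.in_features,
    linear.out_features,
    bias=True,
    compute_dtype=torch.float32,
    quant_type="nf4",
    device="meta",
)
# wrapping the fp32 weight in Params4bit defers quantization until the move to CUDA
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, quant_type="nf4", requires_grad=False)
linear_q.bias = torch.nn.Parameter(linear.bias)
linear_q = linear_q.to("cuda")  # the 4-bit quantization happens on this transfer
out = linear_q(torch.randn(1, 300, dtype=torch.float32, device="cuda"))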
......
......@@ -19,6 +19,7 @@ from tests.helpers import (
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
......@@ -50,7 +51,9 @@ def test_linear_no_igemmlt():
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
linear.weight.data.clone(),
requires_grad=False,
has_fp16_weights=False,
).to(linear.weight.dtype)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
......@@ -77,7 +80,14 @@ def test_linear_no_igemmlt():
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
def test_linear_serialization(
has_fp16_weights,
serialize_before_forward,
deserialize_before_cuda,
force_no_igemmlt,
save_before_forward,
load_before_cuda,
):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)
......@@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
linear.weight.data.clone(),
requires_grad=has_fp16_weights,
has_fp16_weights=has_fp16_weights,
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
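A minimal sketch of the Int8 linear setup these serialization tests build on, assuming a CUDA device; the 32x96 shape and the 6.0 outlier threshold mirror values used in the surrounding tests.

import torch
import bitsandbytes as bnb

linear = torch.nn.Linear(32, 96)
linear_custom = bnb.nn.Linear8bitLt(
    linear.in_features,
    linear.out_features,
    linear.bias is not None,
    has_fp16_weights=False,
    threshold=6.0,
)
# wrap a copy of the fp32 weights; int8 quantization happens on the move to CUDA
linear_custom.weight = bnb.nn.Int8Params(
    linear.weight.data.clone(),
    requires_grad=False,
    has_fp16_weights=False,
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
x = torch.randn(3, 32, dtype=torch.half, device="cuda")
out = linear_custom(x)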
......
......@@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super().__init__()
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
dim1,
dim2,
has_fp16_weights=has_fp16_weights,
memory_efficient_backward=memory_efficient_backward,
threshold=threshold,
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
dim2,
dim1,
has_fp16_weights=has_fp16_weights,
memory_efficient_backward=memory_efficient_backward,
threshold=threshold,
)
def forward(self, x):
......@@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm
std = torch.std(x)
......@@ -122,9 +126,7 @@ class LinearFunction(torch.autograd.Function):
return x.to(dtype)
def get_8bit_linear(x, stochastic=False):
round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max()
x = x / max1 * 127
x = round_func(x) / 127 * max1
......@@ -133,9 +135,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = (x * 127) / max1
......@@ -219,9 +219,7 @@ class LinearFunction(torch.autograd.Function):
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(
outputq, S1, S2, x.dtype, args.quant_type
)
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
......@@ -250,37 +248,25 @@ class LinearFunction(torch.autograd.Function):
# weight and x are already 8bit
# -> transform grad_output to 8-bit
if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=2
)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(
grad_input8, S1, S3, grad_output.dtype, args.quant_type
)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
else:
grad_input = grad_output.matmul(weight)
......@@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True)
opt2.step()
opt2.zero_grad(True)
assert_all_approx_close(
l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
)
assert_all_approx_close(
l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
)
assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
......@@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient():
@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
l1 = (
bnb.nn.Linear8bitLt(
32,
64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
.cuda()
.half()
)
assert l1.weight.dtype == torch.int8
l1.eval()
......@@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
......@@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.half()
.cuda()
)
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
......@@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
mlp = (
MLP8bit(
32,
64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
.half()
.to("cuda")
)
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
......@@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc2.weight.device.type == "cuda"
mlp = MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
32,
64,
threshold=threshold,
has_fp16_weights=False,
memory_efficient_backward=memory_efficient_backward,
)
w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization
......@@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
bnb.nn.LinearFP4,
],
ids=['Int8Lt', 'FP4'],
ids=["Int8Lt", "FP4"],
)
def test_linear_kbit_fp32_bias(module):
# casts model to fp16 -> int8 automatically
......@@ -544,7 +542,7 @@ def test_kbit_backprop(module):
kbit[1].bias.detach().copy_(ref[1].bias)
ref = ref.half().cuda()
kbit = kbit.half().cuda()
kbit = kbit.half().to('cuda')
kbit = kbit.half().to("cuda")
errs1 = []
errs2 = []
......@@ -562,10 +560,10 @@ def test_kbit_backprop(module):
bgrad1 = ref[0].bias.grad
bgrad2 = kbit[0].bias.grad
err1 = (out1-out2).abs().float()
err2 = (grad1-grad2).abs().float()
relerr1 = (err1/(out1.abs().float()+1e-9))
relerr2 = (err2/(grad1.abs().float()+1e-9))
err1 = (out1 - out2).abs().float()
err2 = (grad1 - grad2).abs().float()
relerr1 = err1 / (out1.abs().float() + 1e-9)
relerr2 = err2 / (grad1.abs().float() + 1e-9)
errs1.append(err1.mean().item())
errs2.append(err2.mean().item())
relerrs1.append(relerr1.mean().item())
......@@ -582,20 +580,20 @@ def test_kbit_backprop(module):
assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
#print('out', sum(errs1)/len(errs1))
#print('grad', sum(errs2)/len(errs2))
#print('rel out', sum(relerrs1)/len(relerrs1))
#print('rel grad', sum(relerrs2)/len(relerrs2))
# print('out', sum(errs1)/len(errs1))
# print('grad', sum(errs2)/len(errs2))
# print('rel out', sum(relerrs1)/len(relerrs1))
# print('rel grad', sum(relerrs2)/len(relerrs2))
def test_fp8linear():
def test_fp8linear():
b = 10
h = 1024
inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h*2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
fp32b = torch.nn.Linear(h*2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
fp32 = torch.nn.Linear(h, h * 2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
fp32b = torch.nn.Linear(h * 2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)
......@@ -605,34 +603,34 @@ def test_fp8linear():
a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp)))
err = (a-b).abs().mean()
err = (a - b).abs().mean()
a.mean().backward()
b.mean().backward()
graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
assert err < 0.05
assert graderr < 0.00002
assert bgraderr < 0.00002
def test_4bit_warnings():
dim1 = 64
with pytest.warns(UserWarning, match=r'inference or training'):
with pytest.warns(UserWarning, match=r"inference or training"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(10, dim1).cuda().half()
net(inp)
with pytest.warns(UserWarning, match=r'inference.'):
with pytest.warns(UserWarning, match=r"inference."):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(1, dim1).cuda().half()
net(inp)
with pytest.warns(UserWarning) as record:
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(10, dim1).cuda().half()
......
......@@ -16,6 +16,7 @@ from tests.helpers import describe_dtype, id_formatter
k = 20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol=rtol, atol=atol)
error_count = (idx == 0).sum().item()
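# Note added for clarity (not part of the diff): the hunk above only shows the first lines of
# the helper. A minimal sketch of what such a tolerant comparison typically looks like — the
# name and the exact body here are illustrative, the real helper in the repository may differ:
def _assert_most_approx_close_sketch(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
        # fall back to the strict comparison so pytest prints a useful mismatch report
        print(f"Too many values not close: assert {error_count} <= {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)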
......@@ -33,6 +34,7 @@ def get_temp_dir():
def rm_path(path):
shutil.rmtree(path)
str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
......@@ -66,8 +68,14 @@ str2optimizers["rmsprop8bit"] = (
)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (
torch.optim.AdamW,
lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)
str2optimizers["paged_adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
......@@ -90,9 +98,18 @@ str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
......@@ -101,7 +118,7 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"]
@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
......@@ -109,7 +126,7 @@ optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_ad
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']:
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip()
if dim1 == 1 and dim2 == 1:
return
......@@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
for name1, name2 in str2statenames[optim_name]:
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
atol=atol, rtol=rtol,
max_error_count=10)
assert_most_approx_close(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
max_error_count=10,
)
if gtype != torch.float32:
# the adam buffers should also be close because they are 32-bit
......@@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(
p3, "optim_bits", 8
)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters(
[p1, p2, p3]
)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
p1 = p1.cuda()
p2 = p2.cuda()
p3 = p3.cuda()
......@@ -242,7 +259,8 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
pytest.skip()
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
......@@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
#assert num_not_close.sum().item() < 20
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
# assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1)+1e-9)
relerr = err / (torch.abs(p1) + 1e-9)
if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015
assert relerr.mean() < 0.0016
......@@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
......@@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
)
torch.testing.assert_close(s1cpy, s1)
num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
......@@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
for i in range(50):
step += 1
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
0.01 * i
)
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
g2 = g1.clone()
p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
g1, gnorm_vec, step, 5
)
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1
......@@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode"))
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
......@@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1 = layers1.cuda()
large_tensor = None
if mode == 'torch':
if mode == "torch":
optim = str2optimizers[optim_name][0](layers1.parameters())
else:
optim = str2optimizers[optim_name][1](layers1.parameters())
# 12 GB
large_tensor = torch.empty((int(4.5e9),), device='cuda')
large_tensor = torch.empty((int(4.5e9),), device="cuda")
torch.cuda.synchronize()
time.sleep(5)
num_batches = 5
batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()
batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
for i in range(num_batches):
print(i)
b = batches[i]
if i ==2:
if i == 2:
torch.cuda.synchronize()
t0 = time.time()
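            # Note added for clarity (not part of the diff): the large dummy tensor allocated above
            # is presumably there to put the GPU under memory pressure, so that the paged optimizer
            # states (e.g. PagedAdamW) are actually evicted to CPU memory and the benchmark
            # exercises the paging path instead of keeping all state resident on the GPU.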
......