Commit 5a4263f4 authored by Ruff's avatar Ruff Committed by Aarni Koskela
Browse files

Reformat with ruff-format

parent 02e30ca6
......@@ -3,14 +3,14 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# It's purpose is fused matmul then dequantize
......@@ -27,57 +27,82 @@ else:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
num_stages=num_stages,
num_warps=num_warps,
),
)
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
num_stages=num_stages,
num_warps=num_warps,
pre_hook=init_to_zero("C"),
),
)
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
*get_configs_io_bound(),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
key=["M", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
def _int8_matmul_mixed_dequantize(
A,
B,
C,
bias,
state_x_ptr,
state_w_ptr,
M,
N,
K,
divfactor: tl.constexpr,
has_bias: tl.constexpr,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
......@@ -115,13 +140,13 @@ else:
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = w_factor * (x_factor * (acc * divfactor))
acc = acc.to(C.dtype.element_ty)
# conditionally add bias
......@@ -137,10 +162,9 @@ else:
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
divfactor = 1.0 / (127.0 * 127.0)
has_bias = 0 if bias is None else 1
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
......@@ -154,12 +178,28 @@ else:
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_mixed_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
_int8_matmul_mixed_dequantize[grid](
a,
b,
c,
bias,
state_x,
state_w,
M,
N,
K,
divfactor,
has_bias,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
GROUP_M=8,
ACC_TYPE=ACC_TYPE,
)
return c
......@@ -3,7 +3,9 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
return None
else:
import triton
import triton.language as tl
......@@ -17,7 +19,6 @@ else:
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
......@@ -26,57 +27,82 @@ else:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
num_stages=num_stages,
num_warps=num_warps,
),
)
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
num_stages=num_stages,
num_warps=num_warps,
pre_hook=init_to_zero("C"),
),
)
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
*get_configs_io_bound(),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
key=["M", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
def _int8_matmul_rowwise_dequantize(
A,
B,
C,
bias,
state_x_ptr,
state_w_ptr,
M,
N,
K,
divfactor,
has_bias: tl.constexpr,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
......@@ -114,13 +140,13 @@ else:
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = w_factor * (x_factor * (acc * divfactor))
acc = acc.to(C.dtype.element_ty)
if has_bias:
......@@ -135,9 +161,8 @@ else:
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
divfactor = 1. / (127. * 127.)
divfactor = 1.0 / (127.0 * 127.0)
has_bias = 0 if bias is None else 1
......@@ -154,12 +179,28 @@ else:
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_rowwise_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
_int8_matmul_rowwise_dequantize[grid](
a,
b,
c,
bias,
state_x,
state_w,
M,
N,
K,
divfactor,
has_bias,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
GROUP_M=8,
ACC_TYPE=ACC_TYPE,
)
return c
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
else:
def quantize_columnwise_and_transpose(x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -31,7 +32,7 @@ else:
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
key=["n_elements"],
)
@triton.jit
def _quantize_columnwise_and_transpose(
......@@ -39,7 +40,8 @@ else:
output_ptr,
output_maxs,
n_elements,
M : tl.constexpr, N : tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
......@@ -52,7 +54,7 @@ else:
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
output = tl.libdevice.llrint(127.0 * (x / max_val))
new_start = pid * M
new_offsets = new_start + p2_arange
......@@ -68,6 +70,6 @@ else:
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
return output, output_maxs
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_global_transpose(input): return None
def quantize_global(x: torch.Tensor): return None
else:
def quantize_global_transpose(input):
return None
def quantize_global(x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
# global quantize
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
],
key=['n_elements']
key=["n_elements"],
)
@triton.jit
def _quantize_global(
......@@ -34,35 +35,43 @@ else:
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
output = tl.libdevice.llrint(127. * (x * absmax_inv))
output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
tl.store(output_ptr + offsets, output, mask=mask)
def quantize_global(x: torch.Tensor):
absmax = x.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
absmax_inv = 1.0 / absmax
output = torch.empty(*x.shape, device="cuda", dtype=torch.int8)
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_quantize_global[grid](x, absmax_inv, output, n_elements)
return output, absmax
# global quantize and transpose
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
# ...
],
key=['M', 'N']
key=["M", "N"],
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
BLOCK_M : tl.constexpr,
BLOCK_N : tl.constexpr,
GROUP_M : tl.constexpr):
def _quantize_global_transpose(
A,
absmax_inv_ptr,
B,
stride_am,
stride_an,
stride_bn,
stride_bm,
M,
N,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
......@@ -86,20 +95,30 @@ else:
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
output = tl.libdevice.llrint(127. * (a * absmax_inv))
output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
tl.store(B, output, mask=mask)
def quantize_global_transpose(input):
absmax = input.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax
absmax_inv = 1.0 / absmax
M, N = input.shape
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
out = torch.empty(N, M, device="cuda", dtype=torch.int8)
assert out.size(0) == N and out.size(1) == M
assert input.stride(0) == 1 or input.stride(1) == 1
assert out.stride(0) == 1 or out.stride(1) == 1
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
_quantize_global_transpose[grid](
input,
absmax_inv,
out,
input.stride(0),
input.stride(1),
out.stride(0),
out.stride(1),
M,
N,
)
return out, absmax
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_rowwise(x: torch.Tensor): return None
else:
def quantize_rowwise(x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -29,7 +30,7 @@ else:
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
key=["n_elements"],
)
@triton.jit
def _quantize_rowwise(
......@@ -49,7 +50,7 @@ else:
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
output = tl.libdevice.llrint(127.0 * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
......
This diff is collapsed.
......@@ -2,14 +2,14 @@ import torch
import bitsandbytes as bnb
p = torch.nn.Parameter(torch.rand(10,10).cuda())
a = torch.rand(10,10).cuda()
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()
p1 = p.data.sum().item()
adam = bnb.optim.Adam([p])
out = a*p
out = a * p
loss = out.sum()
loss.backward()
adam.step()
......@@ -17,5 +17,5 @@ adam.step()
p2 = p.data.sum().item()
assert p1 != p2
print('SUCCESS!')
print('Installation was successful!')
print("SUCCESS!")
print("Installation was successful!")
......@@ -2,23 +2,18 @@ import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
MAX_NEW_TOKENS = 128
model_name = 'meta-llama/Llama-2-7b-hf'
model_name = "meta-llama/Llama-2-7b-hf"
text = 'Hamburg is in which country?\n'
text = "Hamburg is in which country?\n"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
model = LlamaForCausalLM.from_pretrained(
model_name,
device_map='auto',
load_in_8bit=True,
max_memory=max_memory
)
model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory)
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert (
get_cuda_bnb_library_path(cuda111_noblas_spec).stem
== "libbitsandbytes_cuda111_nocublaslt"
)
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment