Commit 5a4263f4 authored by Ruff, committed by Aarni Koskela

Reformat with ruff-format

parent 02e30ca6
@@ -3,14 +3,14 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
+
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
+        return None
 else:
     import triton
     import triton.language as tl
     from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

     # This is a matmul kernel based on triton.ops.matmul
     # It is modified to support rowwise quantized input and global quantized weight
     # It's purpose is fused matmul then dequantize
@@ -27,57 +27,82 @@ else:
                     for block_n in [32, 64, 128, 256]:
                         num_warps = 2 if block_n <= 64 else 4
                         configs.append(
-                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                                          num_stages=num_stages, num_warps=num_warps))
+                            triton.Config(
+                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                                num_stages=num_stages,
+                                num_warps=num_warps,
+                            ),
+                        )
                         # split_k
                         for split_k in [2, 4, 8, 16]:
-                            configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                                         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+                            configs.append(
+                                triton.Config(
+                                    {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                                    num_stages=num_stages,
+                                    num_warps=num_warps,
+                                    pre_hook=init_to_zero("C"),
+                                ),
+                            )
         return configs

     @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             # good for int8
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             *get_configs_io_bound(),
         ],
-        key=['M', 'N', 'K'],
-        prune_configs_by={
-            'early_config_prune': early_config_prune,
-            'perf_model': estimate_matmul_time,
-            'top_k': 10
-        },
-    )
-    @triton.heuristics({
-        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-    })
+        key=["M", "N", "K"],
+        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+    )
+    @triton.heuristics(
+        {
+            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+        },
+    )
     @triton.jit
-    def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
-                                      stride_am, stride_ak,
-                                      stride_bk, stride_bn,
-                                      stride_cm, stride_cn,
-                                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                                      GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                                      ACC_TYPE: tl.constexpr
-                                      ):
+    def _int8_matmul_mixed_dequantize(
+        A,
+        B,
+        C,
+        bias,
+        state_x_ptr,
+        state_w_ptr,
+        M,
+        N,
+        K,
+        divfactor: tl.constexpr,
+        has_bias: tl.constexpr,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        GROUP_M: tl.constexpr,
+        SPLIT_K: tl.constexpr,
+        EVEN_K: tl.constexpr,
+        ACC_TYPE: tl.constexpr,
+    ):
         # matrix multiplication
         pid = tl.program_id(0)
@@ -115,13 +140,13 @@ else:
                 b = tl.load(B)
             else:
                 k_remaining = K - k * (BLOCK_K * SPLIT_K)
-                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
             acc += tl.dot(a, b)
             A += BLOCK_K * SPLIT_K * stride_ak
             B += BLOCK_K * SPLIT_K * stride_bk

-        acc = (w_factor * (x_factor * (acc * divfactor)))
+        acc = w_factor * (x_factor * (acc * divfactor))
         acc = acc.to(C.dtype.element_ty)

         # conditionally add bias
@@ -137,10 +162,9 @@ else:
         else:
             tl.atomic_add(C, acc, mask=mask)

     def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         device = a.device
-        divfactor = 1. / (127. * 127.)
+        divfactor = 1.0 / (127.0 * 127.0)
         has_bias = 0 if bias is None else 1

         # handle non-contiguous inputs if necessary
         if a.stride(0) > 1 and a.stride(1) > 1:
@@ -154,12 +178,28 @@ else:
         # allocates output
         c = torch.empty((M, N), device=device, dtype=torch.float16)
         # accumulator types
-        ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # launch int8_matmul_mixed_dequantize kernel
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                                            a.stride(0), a.stride(1),
-                                            b.stride(0), b.stride(1),
-                                            c.stride(0), c.stride(1),
-                                            GROUP_M=8, ACC_TYPE=ACC_TYPE)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        _int8_matmul_mixed_dequantize[grid](
+            a,
+            b,
+            c,
+            bias,
+            state_x,
+            state_w,
+            M,
+            N,
+            K,
+            divfactor,
+            has_bias,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            GROUP_M=8,
+            ACC_TYPE=ACC_TYPE,
+        )
         return c
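The kernel above fuses the int8 matmul with dequantization: acc = w_factor * (x_factor * (acc * divfactor)), where divfactor = 1 / (127 * 127) undoes the two symmetric-quantization factors. For readers following the diff, here is a minimal pure-PyTorch sketch of the same arithmetic, assuming state_x holds per-row absmax values for a and state_w a single global absmax for b (the _ref helper below is hypothetical, not part of this commit):

    import torch

    def int8_matmul_mixed_dequantize_ref(a, b, state_x, state_w, bias=None):
        # a: (M, K) int8, rowwise-quantized; state_x: (M,) per-row absmax
        # b: (K, N) int8, globally quantized; state_w: scalar absmax
        divfactor = 1.0 / (127.0 * 127.0)
        acc = (a.float() @ b.float()) * divfactor   # undo both 127x quantization factors
        acc = state_w * (state_x[:, None] * acc)    # rowwise x scale, global w scale
        if bias is not None:
            acc = acc + bias
        return acc.to(torch.float16)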
@@ -3,7 +3,9 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
+
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -17,7 +19,6 @@ else:
     def init_to_zero(name):
         return lambda nargs: nargs[name].zero_()

-
     def get_configs_io_bound():
         configs = []
         for num_stages in [2, 3, 4, 5, 6]:
@@ -26,57 +27,82 @@ else:
                     for block_n in [32, 64, 128, 256]:
                         num_warps = 2 if block_n <= 64 else 4
                         configs.append(
-                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                                          num_stages=num_stages, num_warps=num_warps))
+                            triton.Config(
+                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                                num_stages=num_stages,
+                                num_warps=num_warps,
+                            ),
+                        )
                         # split_k
                         for split_k in [2, 4, 8, 16]:
-                            configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                                         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+                            configs.append(
+                                triton.Config(
+                                    {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                                    num_stages=num_stages,
+                                    num_warps=num_warps,
+                                    pre_hook=init_to_zero("C"),
+                                ),
+                            )
         return configs

     @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             # good for int8
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             *get_configs_io_bound(),
         ],
-        key=['M', 'N', 'K'],
-        prune_configs_by={
-            'early_config_prune': early_config_prune,
-            'perf_model': estimate_matmul_time,
-            'top_k': 10
-        },
-    )
-    @triton.heuristics({
-        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-    })
+        key=["M", "N", "K"],
+        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+    )
+    @triton.heuristics(
+        {
+            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+        },
+    )
     @triton.jit
-    def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
-                                        stride_am, stride_ak,
-                                        stride_bk, stride_bn,
-                                        stride_cm, stride_cn,
-                                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                                        GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                                        ACC_TYPE: tl.constexpr
-                                        ):
+    def _int8_matmul_rowwise_dequantize(
+        A,
+        B,
+        C,
+        bias,
+        state_x_ptr,
+        state_w_ptr,
+        M,
+        N,
+        K,
+        divfactor,
+        has_bias: tl.constexpr,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        GROUP_M: tl.constexpr,
+        SPLIT_K: tl.constexpr,
+        EVEN_K: tl.constexpr,
+        ACC_TYPE: tl.constexpr,
+    ):
         # matrix multiplication
         pid = tl.program_id(0)
@@ -114,13 +140,13 @@ else:
                 b = tl.load(B)
             else:
                 k_remaining = K - k * (BLOCK_K * SPLIT_K)
-                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
             acc += tl.dot(a, b)
             A += BLOCK_K * SPLIT_K * stride_ak
             B += BLOCK_K * SPLIT_K * stride_bk

-        acc = (w_factor * (x_factor * (acc * divfactor)))
+        acc = w_factor * (x_factor * (acc * divfactor))
         acc = acc.to(C.dtype.element_ty)

         if has_bias:
@@ -135,9 +161,8 @@ else:
         else:
             tl.atomic_add(C, acc, mask=mask)

     def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
-        divfactor = 1. / (127. * 127.)
+        divfactor = 1.0 / (127.0 * 127.0)

         has_bias = 0 if bias is None else 1
@@ -154,12 +179,28 @@ else:
         # allocates output
         c = torch.empty((M, N), device=device, dtype=torch.float16)
         # accumulator types
-        ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # launch int8_matmul_rowwise_dequantize kernel
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                                              a.stride(0), a.stride(1),
-                                              b.stride(0), b.stride(1),
-                                              c.stride(0), c.stride(1),
-                                              GROUP_M=8, ACC_TYPE=ACC_TYPE)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        _int8_matmul_rowwise_dequantize[grid](
+            a,
+            b,
+            c,
+            bias,
+            state_x,
+            state_w,
+            M,
+            N,
+            K,
+            divfactor,
+            has_bias,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            GROUP_M=8,
+            ACC_TYPE=ACC_TYPE,
+        )
         return c
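Both matmul files share the split-K design visible in get_configs_io_bound: when SPLIT_K > 1, each program computes a partial product over a slice of K and accumulates it into C with tl.atomic_add, which is why those configs register pre_hook=init_to_zero("C") to zero the output first. The rowwise variant dequantizes with one scale per row of a and, judging by the state_w pointer, one scale per output column; a hypothetical sketch under that assumption:

    import torch

    def int8_matmul_rowwise_dequantize_ref(a, b, state_x, state_w, bias=None):
        # a: (M, K) int8 with per-row absmax state_x of shape (M,)
        # b: (K, N) int8 with per-column absmax state_w of shape (N,) (assumed)
        divfactor = 1.0 / (127.0 * 127.0)
        acc = (a.float() @ b.float()) * divfactor
        acc = state_w[None, :] * (state_x[:, None] * acc)  # mirrors the kernel's dequantize step
        if bias is not None:
            acc = acc + bias
        return acc.to(torch.float16)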
@@ -5,9 +5,10 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
+
+    def quantize_columnwise_and_transpose(x: torch.Tensor):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -31,7 +32,7 @@ else:
             triton.Config({}, num_warps=4),
             triton.Config({}, num_warps=8),
         ],
-        key=['n_elements']
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_columnwise_and_transpose(
@@ -39,7 +40,8 @@ else:
         output_ptr,
         output_maxs,
         n_elements,
-        M : tl.constexpr, N : tl.constexpr,
+        M: tl.constexpr,
+        N: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
         P2: tl.constexpr,
     ):
@@ -52,7 +54,7 @@ else:
         x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
         abs_x = tl.abs(x)
         max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
-        output = tl.libdevice.llrint(127. * (x / max_val))
+        output = tl.libdevice.llrint(127.0 * (x / max_val))
         new_start = pid * M
         new_offsets = new_start + p2_arange
@@ -68,6 +70,6 @@ else:
         assert x.is_cuda and output.is_cuda
         n_elements = output.numel()
-        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
         return output, output_maxs
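This kernel assigns one program per column: it reduces the column to its absmax, rounds 127 * x / max_val to the nearest integer with llrint, and writes the result transposed. Roughly equivalent PyTorch, as a hypothetical reference (torch.round also rounds half to even, like llrint):

    import torch

    def quantize_columnwise_and_transpose_ref(x):
        # Quantize each column against its own absmax, then store transposed.
        absmax = x.abs().amax(dim=0)                             # (N,) per-column maxima
        out = torch.round(127.0 * (x / absmax)).to(torch.int8)   # broadcasts over rows
        return out.t().contiguous(), absmax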
 import torch

 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def quantize_global_transpose(input): return None
-    def quantize_global(x: torch.Tensor): return None
+
+    def quantize_global_transpose(input):
+        return None
+
+    def quantize_global(x: torch.Tensor):
+        return None
 else:
     import triton
     import triton.language as tl

     # global quantize
     @triton.autotune(
         configs=[
-            triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
-            triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
+            triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
         ],
-        key=['n_elements']
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_global(
@@ -34,35 +35,43 @@ else:
         mask = offsets < n_elements
         x = tl.load(x_ptr + offsets, mask=mask)
         absmax_inv = tl.load(absmax_inv_ptr)
-        output = tl.libdevice.llrint(127. * (x * absmax_inv))
+        output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
         tl.store(output_ptr + offsets, output, mask=mask)

     def quantize_global(x: torch.Tensor):
         absmax = x.abs().max().unsqueeze(0)
-        absmax_inv = 1./ absmax
-        output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
+        absmax_inv = 1.0 / absmax
+        output = torch.empty(*x.shape, device="cuda", dtype=torch.int8)
         assert x.is_cuda and output.is_cuda
         n_elements = output.numel()
-        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         _quantize_global[grid](x, absmax_inv, output, n_elements)
         return output, absmax

     # global quantize and transpose
     @triton.autotune(
         configs=[
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4),
             # ...
         ],
-        key=['M', 'N']
+        key=["M", "N"],
     )
     @triton.jit
-    def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
-                                   BLOCK_M : tl.constexpr,
-                                   BLOCK_N : tl.constexpr,
-                                   GROUP_M : tl.constexpr):
+    def _quantize_global_transpose(
+        A,
+        absmax_inv_ptr,
+        B,
+        stride_am,
+        stride_an,
+        stride_bn,
+        stride_bm,
+        M,
+        N,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        GROUP_M: tl.constexpr,
+    ):
         pid = tl.program_id(0)
         grid_m = (M + BLOCK_M - 1) // BLOCK_M
         grid_n = (N + BLOCK_N - 1) // BLOCK_N
@@ -86,20 +95,30 @@ else:
         B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
         mask = (rm < M)[:, None] & (rn < N)[None, :]
-        output = tl.libdevice.llrint(127. * (a * absmax_inv))
+        output = tl.libdevice.llrint(127.0 * (a * absmax_inv))
         tl.store(B, output, mask=mask)

     def quantize_global_transpose(input):
         absmax = input.abs().max().unsqueeze(0)
-        absmax_inv = 1./ absmax
+        absmax_inv = 1.0 / absmax
         M, N = input.shape
-        out = torch.empty(N, M, device='cuda', dtype=torch.int8)
+        out = torch.empty(N, M, device="cuda", dtype=torch.int8)

         assert out.size(0) == N and out.size(1) == M
         assert input.stride(0) == 1 or input.stride(1) == 1
         assert out.stride(0) == 1 or out.stride(1) == 1

-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
-        _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+        _quantize_global_transpose[grid](
+            input,
+            absmax_inv,
+            out,
+            input.stride(0),
+            input.stride(1),
+            out.stride(0),
+            out.stride(1),
+            M,
+            N,
+        )
         return out, absmax
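quantize_global scales the entire tensor by one absmax, and quantize_global_transpose fuses the same quantization with a tiled transpose (GROUP_M groups program IDs for better cache locality, as in the matmul kernels). Stripped of the tiling, the arithmetic reduces to the following hypothetical reference:

    import torch

    def quantize_global_ref(x):
        # One absmax for the whole tensor, as in _quantize_global.
        absmax = x.abs().max().unsqueeze(0)
        out = torch.round(127.0 * (x / absmax)).to(torch.int8)
        return out, absmax

    def quantize_global_transpose_ref(x):
        # The Triton version fuses this transpose into the quantization kernel.
        out, absmax = quantize_global_ref(x)
        return out.t().contiguous(), absmax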
@@ -5,9 +5,10 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def quantize_rowwise(x: torch.Tensor): return None
+
+    def quantize_rowwise(x: torch.Tensor):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -29,7 +30,7 @@ else:
             triton.Config({}, num_warps=4),
             triton.Config({}, num_warps=8),
         ],
-        key=['n_elements']
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_rowwise(
@@ -49,7 +50,7 @@ else:
         abs_x = tl.abs(x)
         max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
-        output = tl.libdevice.llrint(127. * (x / max_val))
+        output = tl.libdevice.llrint(127.0 * (x / max_val))
         tl.store(output_ptr + offsets, output, mask=row_mask)
         tl.store(output_maxs + pid, max_val)
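Here each program handles one row, reducing it to its absmax and storing both the rounded int8 row and its scale; these per-row maxima are exactly the state_x values consumed by the dequantizing matmuls above. A hypothetical PyTorch equivalent:

    import torch

    def quantize_rowwise_ref(x):
        # One program per row in the kernel; here, plain broadcasting.
        absmax = x.abs().amax(dim=1)                                    # (M,) per-row maxima
        out = torch.round(127.0 * (x / absmax[:, None])).to(torch.int8)
        return out, absmax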
@@ -2,14 +2,14 @@ import torch
 import bitsandbytes as bnb

-p = torch.nn.Parameter(torch.rand(10,10).cuda())
-a = torch.rand(10,10).cuda()
+p = torch.nn.Parameter(torch.rand(10, 10).cuda())
+a = torch.rand(10, 10).cuda()

 p1 = p.data.sum().item()

 adam = bnb.optim.Adam([p])

-out = a*p
+out = a * p
 loss = out.sum()
 loss.backward()
 adam.step()
@@ -17,5 +17,5 @@ adam.step()
 p2 = p.data.sum().item()
 assert p1 != p2
-print('SUCCESS!')
-print('Installation was successful!')
+print("SUCCESS!")
+print("Installation was successful!")
@@ -2,23 +2,18 @@ import torch
 from transformers import LlamaForCausalLM, LlamaTokenizer

 MAX_NEW_TOKENS = 128
-model_name = 'meta-llama/Llama-2-7b-hf'
+model_name = "meta-llama/Llama-2-7b-hf"

-text = 'Hamburg is in which country?\n'
+text = "Hamburg is in which country?\n"
 tokenizer = LlamaTokenizer.from_pretrained(model_name)
 input_ids = tokenizer(text, return_tensors="pt").input_ids

-max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
+max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"

 n_gpus = torch.cuda.device_count()
 max_memory = {i: max_memory for i in range(n_gpus)}

-model = LlamaForCausalLM.from_pretrained(
-    model_name,
-    device_map='auto',
-    load_in_8bit=True,
-    max_memory=max_memory
-)
+model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory)
 generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
 print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
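Passing load_in_8bit=True straight to from_pretrained is the older transformers calling convention; recent releases route the flag through a quantization config object instead. A sketch of that variant, assuming a transformers version that exposes BitsAndBytesConfig:

    from transformers import BitsAndBytesConfig, LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        max_memory=max_memory,
    )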
@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
 def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert (
-        get_cuda_bnb_library_path(cuda111_noblas_spec).stem
-        == "libbitsandbytes_cuda111_nocublaslt"
-    )
+    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"