"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "c7afb5461e6c4c9298ce5afd3ada1f8436cdc15d"
Commit 5a4263f4 authored by Ruff, committed by Aarni Koskela

Reformat with ruff-format

parent 02e30ca6
bitsandbytes/triton/int8_matmul_mixed_dequantize.py

@@ -3,14 +3,14 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
+
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
+        return None
 else:
     import triton
     import triton.language as tl
     from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
 
     # This is a matmul kernel based on triton.ops.matmul
     # It is modified to support rowwise quantized input and global quantized weight
     # It's purpose is fused matmul then dequantize
@@ -27,58 +27,83 @@ else:
                 for block_n in [32, 64, 128, 256]:
                     num_warps = 2 if block_n <= 64 else 4
                     configs.append(
-                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                                      num_stages=num_stages, num_warps=num_warps))
+                        triton.Config(
+                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                            num_stages=num_stages,
+                            num_warps=num_warps,
+                        ),
+                    )
                     # split_k
                     for split_k in [2, 4, 8, 16]:
-                        configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                                     num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+                        configs.append(
+                            triton.Config(
+                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                                num_stages=num_stages,
+                                num_warps=num_warps,
+                                pre_hook=init_to_zero("C"),
+                            ),
+                        )
         return configs
 
     @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             # good for int8
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             *get_configs_io_bound(),
         ],
-        key=['M', 'N', 'K'],
-        prune_configs_by={
-            'early_config_prune': early_config_prune,
-            'perf_model': estimate_matmul_time,
-            'top_k': 10
-        },
+        key=["M", "N", "K"],
+        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
     )
-    @triton.heuristics({
-        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-    })
+    @triton.heuristics(
+        {
+            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+        },
+    )
     @triton.jit
-    def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
-                        stride_am, stride_ak,
-                        stride_bk, stride_bn,
-                        stride_cm, stride_cn,
-                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                        GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                        ACC_TYPE: tl.constexpr
-                        ):
+    def _int8_matmul_mixed_dequantize(
+        A,
+        B,
+        C,
+        bias,
+        state_x_ptr,
+        state_w_ptr,
+        M,
+        N,
+        K,
+        divfactor: tl.constexpr,
+        has_bias: tl.constexpr,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        GROUP_M: tl.constexpr,
+        SPLIT_K: tl.constexpr,
+        EVEN_K: tl.constexpr,
+        ACC_TYPE: tl.constexpr,
+    ):
         # matrix multiplication
         pid = tl.program_id(0)
         pid_z = tl.program_id(1)
@@ -115,13 +140,13 @@ else:
                 b = tl.load(B)
             else:
                 k_remaining = K - k * (BLOCK_K * SPLIT_K)
-                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
             acc += tl.dot(a, b)
             A += BLOCK_K * SPLIT_K * stride_ak
             B += BLOCK_K * SPLIT_K * stride_bk
 
-        acc = (w_factor * (x_factor * (acc * divfactor)))
+        acc = w_factor * (x_factor * (acc * divfactor))
         acc = acc.to(C.dtype.element_ty)
 
         # conditionally add bias
@@ -137,10 +162,9 @@ else:
         else:
            tl.atomic_add(C, acc, mask=mask)
 
-
    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
        device = a.device
-        divfactor = 1. / (127. * 127.)
+        divfactor = 1.0 / (127.0 * 127.0)
        has_bias = 0 if bias is None else 1
 
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
@@ -154,12 +178,28 @@ else:
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
-        ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_mixed_dequantize kernel
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                            a.stride(0), a.stride(1),
-                            b.stride(0), b.stride(1),
-                            c.stride(0), c.stride(1),
-                            GROUP_M=8, ACC_TYPE=ACC_TYPE)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        _int8_matmul_mixed_dequantize[grid](
+            a,
+            b,
+            c,
+            bias,
+            state_x,
+            state_w,
+            M,
+            N,
+            K,
+            divfactor,
+            has_bias,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            GROUP_M=8,
+            ACC_TYPE=ACC_TYPE,
+        )
        return c
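
For reference: this commit only reformats, so the kernel's math is unchanged. Below is a hedged PyTorch sketch (illustrative, not part of the commit) of what the fused kernel computes, assuming the usual round(127 * x / absmax) int8 convention, with state_x holding one scale per row of a and state_w a single global scale for b:

import torch

# Illustrative reference only; mirrors the kernel's
# acc = w_factor * (x_factor * (acc * divfactor)) dequantize step.
def int8_matmul_mixed_dequantize_ref(a_int8, b_int8, state_x, state_w, bias=None):
    divfactor = 1.0 / (127.0 * 127.0)      # undoes the two 127x quantization factors
    acc = a_int8.float() @ b_int8.float()  # int8 matmul, accumulated in fp32
    out = state_w * (state_x[:, None] * (acc * divfactor))  # global x rowwise scales
    if bias is not None:                   # kernel adds bias when has_bias is set
        out = out + bias
    return out.to(torch.float16)           # kernel writes an fp16 output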
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py

@@ -3,7 +3,9 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
+
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -17,7 +19,6 @@ else:
     def init_to_zero(name):
         return lambda nargs: nargs[name].zero_()
 
-
     def get_configs_io_bound():
         configs = []
         for num_stages in [2, 3, 4, 5, 6]:
@@ -26,58 +27,83 @@ else:
                 for block_n in [32, 64, 128, 256]:
                     num_warps = 2 if block_n <= 64 else 4
                     configs.append(
-                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
-                                      num_stages=num_stages, num_warps=num_warps))
+                        triton.Config(
+                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                            num_stages=num_stages,
+                            num_warps=num_warps,
+                        ),
+                    )
                     # split_k
                     for split_k in [2, 4, 8, 16]:
-                        configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
-                                                     num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+                        configs.append(
+                            triton.Config(
+                                {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k},
+                                num_stages=num_stages,
+                                num_warps=num_warps,
+                                pre_hook=init_to_zero("C"),
+                            ),
+                        )
         return configs
 
     @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             # good for int8
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
-            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
-            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
             *get_configs_io_bound(),
         ],
-        key=['M', 'N', 'K'],
-        prune_configs_by={
-            'early_config_prune': early_config_prune,
-            'perf_model': estimate_matmul_time,
-            'top_k': 10
-        },
+        key=["M", "N", "K"],
+        prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
     )
-    @triton.heuristics({
-        'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-    })
+    @triton.heuristics(
+        {
+            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+        },
+    )
     @triton.jit
-    def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
-                        stride_am, stride_ak,
-                        stride_bk, stride_bn,
-                        stride_cm, stride_cn,
-                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                        GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
-                        ACC_TYPE: tl.constexpr
-                        ):
+    def _int8_matmul_rowwise_dequantize(
+        A,
+        B,
+        C,
+        bias,
+        state_x_ptr,
+        state_w_ptr,
+        M,
+        N,
+        K,
+        divfactor,
+        has_bias: tl.constexpr,
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        BLOCK_K: tl.constexpr,
+        GROUP_M: tl.constexpr,
+        SPLIT_K: tl.constexpr,
+        EVEN_K: tl.constexpr,
+        ACC_TYPE: tl.constexpr,
+    ):
         # matrix multiplication
         pid = tl.program_id(0)
         pid_z = tl.program_id(1)
@@ -114,13 +140,13 @@ else:
                 b = tl.load(B)
             else:
                 k_remaining = K - k * (BLOCK_K * SPLIT_K)
-                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
-                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
+                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
             acc += tl.dot(a, b)
             A += BLOCK_K * SPLIT_K * stride_ak
             B += BLOCK_K * SPLIT_K * stride_bk
 
-        acc = (w_factor * (x_factor * (acc * divfactor)))
+        acc = w_factor * (x_factor * (acc * divfactor))
         acc = acc.to(C.dtype.element_ty)
 
         if has_bias:
@@ -135,9 +161,8 @@ else:
        else:
            tl.atomic_add(C, acc, mask=mask)
 
-
    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
-        divfactor = 1. / (127. * 127.)
+        divfactor = 1.0 / (127.0 * 127.0)
 
        has_bias = 0 if bias is None else 1
@@ -154,12 +179,28 @@ else:
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
-        ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
+        ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_rowwise_dequantize kernel
-        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
-                            a.stride(0), a.stride(1),
-                            b.stride(0), b.stride(1),
-                            c.stride(0), c.stride(1),
-                            GROUP_M=8, ACC_TYPE=ACC_TYPE)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
+        _int8_matmul_rowwise_dequantize[grid](
+            a,
+            b,
+            c,
+            bias,
+            state_x,
+            state_w,
+            M,
+            N,
+            K,
+            divfactor,
+            has_bias,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            GROUP_M=8,
+            ACC_TYPE=ACC_TYPE,
+        )
        return c
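
The rowwise variant differs from the mixed one only in how the weight scale is applied: one scale per output column instead of a single global scale. A hedged PyTorch sketch (illustrative, not part of the commit), under the same round(127 * x / absmax) assumption:

import torch

# Illustrative reference only; here state_w holds one scale per output column.
def int8_matmul_rowwise_dequantize_ref(a_int8, b_int8, state_x, state_w, bias=None):
    divfactor = 1.0 / (127.0 * 127.0)
    acc = a_int8.float() @ b_int8.float()
    out = state_w[None, :] * (state_x[:, None] * (acc * divfactor))
    if bias is not None:
        out = out + bias
    return out.to(torch.float16)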
bitsandbytes/triton/quantize_columnwise_and_transpose.py

@@ -5,9 +5,10 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
+
+    def quantize_columnwise_and_transpose(x: torch.Tensor):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -15,23 +16,23 @@ else:
     # TODO: autotune this better.
     @triton.autotune(
         configs=[
             triton.Config({}, num_stages=1),
             triton.Config({}, num_stages=2),
             triton.Config({}, num_stages=4),
             triton.Config({}, num_stages=8),
             triton.Config({}, num_stages=16),
             triton.Config({}, num_stages=1, num_warps=8),
             triton.Config({}, num_stages=2, num_warps=8),
             triton.Config({}, num_stages=4, num_warps=8),
             triton.Config({}, num_stages=8, num_warps=8),
             triton.Config({}, num_stages=16, num_warps=8),
             triton.Config({}, num_warps=1),
             triton.Config({}, num_warps=2),
             triton.Config({}, num_warps=4),
             triton.Config({}, num_warps=8),
         ],
-        key=['n_elements']
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_columnwise_and_transpose(
@@ -39,7 +40,8 @@ else:
         output_ptr,
         output_maxs,
         n_elements,
-        M : tl.constexpr, N : tl.constexpr,
+        M: tl.constexpr,
+        N: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
         P2: tl.constexpr,
     ):
@@ -47,12 +49,12 @@ else:
         block_start = pid
         p2_arange = tl.arange(0, P2)
         p2_arange_mask = p2_arange < M
         arange = p2_arange * N
         offsets = block_start + arange
         x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
         abs_x = tl.abs(x)
         max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
-        output = tl.libdevice.llrint(127. * (x / max_val))
+        output = tl.libdevice.llrint(127.0 * (x / max_val))
         new_start = pid * M
         new_offsets = new_start + p2_arange
@@ -68,6 +70,6 @@ else:
         assert x.is_cuda and output.is_cuda
         n_elements = output.numel()
-        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
         return output, output_maxs
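
For reference: the kernel above quantizes each column by its absolute maximum and stores the result transposed; P2 is the next power of two at or above M, since tl.arange needs a power-of-two extent. A hedged PyTorch sketch of the same semantics (rounding mode aside; the kernel uses llrint):

import torch

# Illustrative reference only: per-column absmax int8 quantization, output transposed.
def quantize_columnwise_and_transpose_ref(x: torch.Tensor):
    col_absmax = x.abs().amax(dim=0)                          # one scale per column
    q = torch.round(127.0 * (x / col_absmax)).to(torch.int8)  # llrint in the kernel
    return q.t().contiguous(), col_absmax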
bitsandbytes/triton/quantize_rowwise.py

@@ -5,9 +5,10 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def quantize_rowwise(x: torch.Tensor): return None
+
+    def quantize_rowwise(x: torch.Tensor):
+        return None
 else:
     import triton
     import triton.language as tl
@@ -15,21 +16,21 @@ else:
     # TODO: autotune this better.
     @triton.autotune(
         configs=[
             triton.Config({}, num_stages=1, num_warps=8),
             triton.Config({}, num_stages=2, num_warps=8),
             triton.Config({}, num_stages=4, num_warps=8),
             triton.Config({}, num_stages=8, num_warps=8),
             triton.Config({}, num_stages=1),
             triton.Config({}, num_stages=2),
             triton.Config({}, num_stages=4),
             triton.Config({}, num_stages=8),
             triton.Config({}, num_warps=1),
             triton.Config({}, num_warps=2),
             triton.Config({}, num_warps=4),
             triton.Config({}, num_warps=8),
         ],
-        key=['n_elements']
+        key=["n_elements"],
     )
     @triton.jit
     def _quantize_rowwise(
@@ -49,7 +50,7 @@ else:
         abs_x = tl.abs(x)
         max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
-        output = tl.libdevice.llrint(127. * (x / max_val))
+        output = tl.libdevice.llrint(127.0 * (x / max_val))
         tl.store(output_ptr + offsets, output, mask=row_mask)
         tl.store(output_maxs + pid, max_val)
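
quantize_rowwise is the row-major counterpart: one absmax scale per row, no transpose. A hedged PyTorch sketch (illustrative, not part of the commit):

import torch

# Illustrative reference only: per-row absmax int8 quantization.
def quantize_rowwise_ref(x: torch.Tensor):
    row_absmax = x.abs().amax(dim=1)                                   # one scale per row
    q = torch.round(127.0 * (x / row_absmax[:, None])).to(torch.int8)  # llrint in the kernel
    return q, row_absmax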
tests/test_cuda_setup_evaluator.py

@@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
 def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert (
-        get_cuda_bnb_library_path(cuda111_noblas_spec).stem
-        == "libbitsandbytes_cuda111_nocublaslt"
-    )
+    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"