Commit c3d87e44 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Added is_available_triton guard.

parent 7140c014
...@@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function): ...@@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
return grad_A, grad_B, None, None, None, None, None return grad_A, grad_B, None, None, None, None, None
class MatMul8bitMixed(torch.autograd.Function): class SwitchBackBnb(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty # default to pytorch behavior if inputs are empty
...@@ -408,4 +408,4 @@ def switchback_bnb( ...@@ -408,4 +408,4 @@ def switchback_bnb(
state = state or MatmulLtState() state = state or MatmulLtState()
if threshold > 0.0: if threshold > 0.0:
state.threshold = threshold state.threshold = threshold
return MatMul8bitMixed.apply(A, B, out, bias, state) return SwitchBackBnb.apply(A, B, out, bias, state)
import math import math
import torch import torch
import time import time
import triton from bitsandbytes.triton.triton_utils import is_triton_available
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# rowwise quantize if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
# TODO: autotune this better. import triton
@triton.autotune( import triton.language as tl
configs=[ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _dequantize_rowwise(
x_ptr,
state_x,
output_ptr,
inv_127,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
max_val = tl.load(state_x + pid)
output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): # rowwise quantize
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) # TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _dequantize_rowwise(
x_ptr,
state_x,
output_ptr,
inv_127,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
max_val = tl.load(state_x + pid)
output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask)
assert x.is_cuda and output.is_cuda def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
n_elements = output.numel() output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
return output
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output
import torch import torch
from bitsandbytes.triton.triton_utils import is_triton_available
import triton if not is_triton_available():
import triton.language as tl def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def init_to_zero(name): # This is a matmul kernel based on triton.ops.matmul
return lambda nargs: nargs[name].zero_() # It is modified to support rowwise quantized input and global quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def get_configs_io_bound(): def init_to_zero(name):
configs = [] return lambda nargs: nargs[name].zero_()
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
# rematerialize rm and rn to save registers @triton.autotune(
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) configs=[
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
w_factor = tl.load(state_w_ptr) # rematerialize rm and rn to save registers
x_factor = tl.load(state_x_ptr + ram)[:, None] rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) w_factor = tl.load(state_w_ptr)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) x_factor = tl.load(state_x_ptr + ram)[:, None]
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K: # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
a = tl.load(A) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
b = tl.load(B) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
else: if EVEN_K:
k_remaining = K - k * (BLOCK_K * SPLIT_K) a = tl.load(A)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) b = tl.load(B)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) else:
acc += tl.dot(a, b) k_remaining = K - k * (BLOCK_K * SPLIT_K)
A += BLOCK_K * SPLIT_K * stride_ak a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
B += BLOCK_K * SPLIT_K * stride_bk b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
acc = (w_factor * (x_factor * (acc * divfactor))) A += BLOCK_K * SPLIT_K * stride_ak
acc = acc.to(C.dtype.element_ty) B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
# conditionally add bias # conditionally add bias
if has_bias: if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty) bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :] acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :] mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting # handles write-back with reduction-splitting
if SPLIT_K == 1: if SPLIT_K == 1:
tl.store(C, acc, mask=mask) tl.store(C, acc, mask=mask)
else: else:
tl.atomic_add(C, acc, mask=mask) tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
device = a.device device = a.device
divfactor = 1. / (127. * 127.) divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1 has_bias = 0 if bias is None else 1
# handle non-contiguous inputs if necessary # handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous() a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1: if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous() b = b.contiguous()
# checks constraints # checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions" assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape M, K = a.shape
_, N = b.shape _, N = b.shape
# allocates output # allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16) c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types # accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_mixed_dequantize kernel # launch int8_matmul_mixed_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
b.stride(0), b.stride(1), b.stride(0), b.stride(1),
c.stride(0), c.stride(1), c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE) GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c return c
import torch import torch
import triton from bitsandbytes.triton.triton_utils import is_triton_available
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time if not is_triton_available():
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
# This is a matmul kernel based on triton.ops.matmul else:
# It is modified to support rowwise quantized input and columnwise quantized weight import triton
# It's purpose is fused matmul then dequantize import triton.language as tl
# It does support bias. from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name): # This is a matmul kernel based on triton.ops.matmul
return lambda nargs: nargs[name].zero_() # It is modified to support rowwise quantized input and columnwise quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def get_configs_io_bound():
configs = [] def init_to_zero(name):
for num_stages in [2, 3, 4, 5, 6]: return lambda nargs: nargs[name].zero_()
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]: def get_configs_io_bound():
num_warps = 2 if block_n <= 64 else 4 configs = []
configs.append( for num_stages in [2, 3, 4, 5, 6]:
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, for block_m in [16, 32]:
num_stages=num_stages, num_warps=num_warps)) for block_k in [32, 64]:
# split_k for block_n in [32, 64, 128, 256]:
for split_k in [2, 4, 8, 16]: num_warps = 2 if block_n <= 64 else 4
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, configs.append(
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
return configs num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
@triton.autotune( configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
configs=[ num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
# basic configs for compute-bound matmuls return configs
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), @triton.autotune(
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), # basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# good for int8 triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), # good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
] + get_configs_io_bound(), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
key=['M', 'N', 'K'], triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
prune_configs_by={ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
'early_config_prune': early_config_prune, triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
'perf_model': estimate_matmul_time, triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
'top_k': 10 ] + get_configs_io_bound(),
}, key=['M', 'N', 'K'],
) prune_configs_by={
@triton.heuristics({ 'early_config_prune': early_config_prune,
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 'perf_model': estimate_matmul_time,
}) 'top_k': 10
@triton.jit },
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, )
stride_am, stride_ak, @triton.heuristics({
stride_bk, stride_bn, 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
stride_cm, stride_cn, })
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @triton.jit
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
ACC_TYPE: tl.constexpr stride_am, stride_ak,
): stride_bk, stride_bn,
# matrix multiplication stride_cm, stride_cn,
pid = tl.program_id(0) BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
pid_z = tl.program_id(1) GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
grid_m = tl.cdiv(M, BLOCK_M) ACC_TYPE: tl.constexpr
grid_n = tl.cdiv(N, BLOCK_N) ):
# re-order program ID for better L2 performance # matrix multiplication
width = GROUP_M * grid_n pid = tl.program_id(0)
group_id = pid // width pid_z = tl.program_id(1)
group_size = min(grid_m - group_id * GROUP_M, GROUP_M) grid_m = tl.cdiv(M, BLOCK_M)
pid_m = group_id * GROUP_M + (pid % group_size) grid_n = tl.cdiv(N, BLOCK_N)
pid_n = (pid % width) // (group_size) # re-order program ID for better L2 performance
# do matrix multiplication width = GROUP_M * grid_n
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) group_id = pid // width
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) pid_m = group_id * GROUP_M + (pid % group_size)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) pid_n = (pid % width) // (group_size)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) # do matrix multiplication
# pointers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
# rematerialize rm and rn to save registers rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # pointers
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
w_factor = tl.load(state_w_ptr + rbn)[None, :]
x_factor = tl.load(state_x_ptr + ram)[:, None] # rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): w_factor = tl.load(state_w_ptr + rbn)[None, :]
if EVEN_K: x_factor = tl.load(state_x_ptr + ram)[:, None]
a = tl.load(A)
b = tl.load(B) # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else: else:
k_remaining = K - k * (BLOCK_K * SPLIT_K) tl.atomic_add(C, acc, mask=mask)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b) def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
A += BLOCK_K * SPLIT_K * stride_ak divfactor = 1. / (127. * 127.)
B += BLOCK_K * SPLIT_K * stride_bk
has_bias = 0 if bias is None else 1
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty) device = a.device
# handle non-contiguous inputs if necessary
if has_bias: if a.stride(0) > 1 and a.stride(1) > 1:
bias = tl.load(bias + rn).to(C.dtype.element_ty) a = a.contiguous()
acc = acc + bias[None, :] if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) # checks constraints
mask = (rm < M)[:, None] & (rn < N)[None, :] assert a.shape[1] == b.shape[0], "incompatible dimensions"
# handles write-back with reduction-splitting M, K = a.shape
if SPLIT_K == 1: _, N = b.shape
tl.store(C, acc, mask=mask) # allocates output
else: c = torch.empty((M, N), device=device, dtype=torch.float16)
tl.atomic_add(C, acc, mask=mask) # accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_rowwise_dequantize kernel
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
divfactor = 1. / (127. * 127.) _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
has_bias = 0 if bias is None else 1 b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
device = a.device GROUP_M=8, ACC_TYPE=ACC_TYPE)
# handle non-contiguous inputs if necessary return c
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_rowwise_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
import math import math
import torch import torch
import time import time
import triton from bitsandbytes.triton.triton_utils import is_triton_available
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This kernel does fused columnwise quantization and transpose. if not is_triton_available():
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
else:
# TODO: autotune this better. import triton
@triton.autotune( import triton.language as tl
configs=[ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_stages=16),
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=16, num_warps=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_columnwise_and_transpose(
x_ptr,
output_ptr,
output_maxs,
n_elements,
M : tl.constexpr, N : tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid
p2_arange = tl.arange(0, P2)
p2_arange_mask = p2_arange < M
arange = p2_arange * N
offsets = block_start + arange
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
new_start = pid * M # This kernel does fused columnwise quantization and transpose.
new_offsets = new_start + p2_arange
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
tl.store(output_maxs + pid, max_val)
def quantize_columnwise_and_transpose(x: torch.Tensor): # TODO: autotune this better.
M, N = x.shape @triton.autotune(
output = torch.empty(N, M, device=x.device, dtype=torch.int8) configs=[
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_stages=16),
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=16, num_warps=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_columnwise_and_transpose(
x_ptr,
output_ptr,
output_maxs,
n_elements,
M : tl.constexpr, N : tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid
p2_arange = tl.arange(0, P2)
p2_arange_mask = p2_arange < M
arange = p2_arange * N
offsets = block_start + arange
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
P2 = int(2 ** (math.ceil(math.log2(M)))) new_start = pid * M
new_offsets = new_start + p2_arange
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
tl.store(output_maxs + pid, max_val)
assert x.is_cuda and output.is_cuda def quantize_columnwise_and_transpose(x: torch.Tensor):
n_elements = output.numel() M, N = x.shape
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) output = torch.empty(N, M, device=x.device, dtype=torch.int8)
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
return output, output_maxs
P2 = int(2 ** (math.ceil(math.log2(M))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
return output, output_maxs
import math import math
import torch import torch
import time import time
import triton from bitsandbytes.triton.triton_utils import is_triton_available
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# global quantize if not is_triton_available():
@triton.autotune( def quantize_global_transpose(input): return None
configs=[ def quantize_global(x: torch.Tensor): return None
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), else:
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
], import triton
key=['n_elements'] import triton.language as tl
) from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
@triton.jit
def _quantize_global(
x_ptr,
absmax_inv_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
output = tl.libdevice.llrint(127. * (x * absmax_inv))
tl.store(output_ptr + offsets, output, mask=mask)
def quantize_global(x: torch.Tensor): # global quantize
absmax = x.abs().max().unsqueeze(0) @triton.autotune(
absmax_inv = 1./ absmax configs=[
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
assert x.is_cuda and output.is_cuda triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_quantize_global[grid](x, absmax_inv, output, n_elements)
return output, absmax
],
key=['n_elements']
)
@triton.jit
def _quantize_global(
x_ptr,
absmax_inv_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
output = tl.libdevice.llrint(127. * (x * absmax_inv))
tl.store(output_ptr + offsets, output, mask=mask)
# global quantize and transpose def quantize_global(x: torch.Tensor):
@triton.autotune( absmax = x.abs().max().unsqueeze(0)
configs=[ absmax_inv = 1./ absmax
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_quantize_global[grid](x, absmax_inv, output, n_elements)
return output, absmax
# ...
],
key=['M', 'N']
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
BLOCK_M : tl.constexpr,
BLOCK_N : tl.constexpr,
GROUP_M : tl.constexpr):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
mask = (rm < M)[:, None] & (rn < N)[None, :]
a = tl.load(A, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
# rematerialize to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
output = tl.libdevice.llrint(127. * (a * absmax_inv)) # global quantize and transpose
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
tl.store(B, output, mask=mask) # ...
],
key=['M', 'N']
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
BLOCK_M : tl.constexpr,
BLOCK_N : tl.constexpr,
GROUP_M : tl.constexpr):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
mask = (rm < M)[:, None] & (rn < N)[None, :]
a = tl.load(A, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
# rematerialize to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
def quantize_global_transpose(input): output = tl.libdevice.llrint(127. * (a * absmax_inv))
absmax = input.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax tl.store(B, output, mask=mask)
M, N = input.shape
out = torch.empty(N, M, device='cuda', dtype=torch.int8) def quantize_global_transpose(input):
absmax = input.abs().max().unsqueeze(0)
assert out.size(0) == N and out.size(1) == M absmax_inv = 1./ absmax
assert input.stride(0) == 1 or input.stride(1) == 1 M, N = input.shape
assert out.stride(0) == 1 or out.stride(1) == 1 out = torch.empty(N, M, device='cuda', dtype=torch.int8)
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) assert out.size(0) == N and out.size(1) == M
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) assert input.stride(0) == 1 or input.stride(1) == 1
return out, absmax assert out.stride(0) == 1 or out.stride(1) == 1
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
return out, absmax
import math import math
import torch import torch
import time import time
import triton
import triton.language as tl from bitsandbytes.triton.triton_utils import is_triton_available
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
if not is_triton_available():
# rowwise quantize def quantize_rowwise(x: torch.Tensor): return None
else:
# TODO: autotune this better.
@triton.autotune( import triton
configs=[ import triton.language as tl
triton.Config({}, num_stages=1, num_warps=8), from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8), # rowwise quantize
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1), # TODO: autotune this better.
triton.Config({}, num_stages=2), @triton.autotune(
triton.Config({}, num_stages=4), configs=[
triton.Config({}, num_stages=8), triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_warps=1), triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_warps=2), triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_warps=4), triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_warps=8), triton.Config({}, num_stages=1),
], triton.Config({}, num_stages=2),
key=['n_elements'] triton.Config({}, num_stages=4),
) triton.Config({}, num_stages=8),
@triton.jit triton.Config({}, num_warps=1),
def _quantize_rowwise( triton.Config({}, num_warps=2),
x_ptr, triton.Config({}, num_warps=4),
output_ptr, triton.Config({}, num_warps=8),
output_maxs, ],
n_elements, key=['n_elements']
BLOCK_SIZE: tl.constexpr, )
P2: tl.constexpr, @triton.jit
): def _quantize_rowwise(
pid = tl.program_id(axis=0) x_ptr,
block_start = pid * BLOCK_SIZE output_ptr,
arange = tl.arange(0, P2) output_maxs,
offsets = block_start + arange n_elements,
row_mask = arange < BLOCK_SIZE BLOCK_SIZE: tl.constexpr,
x = tl.load(x_ptr + offsets, mask=row_mask) P2: tl.constexpr,
):
abs_x = tl.abs(x) pid = tl.program_id(axis=0)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) block_start = pid * BLOCK_SIZE
output = tl.libdevice.llrint(127. * (x / max_val)) arange = tl.arange(0, P2)
tl.store(output_ptr + offsets, output, mask=row_mask) offsets = block_start + arange
tl.store(output_maxs + pid, max_val) row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
def quantize_rowwise(x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) abs_x = tl.abs(x)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
assert x.is_cuda and output.is_cuda
n_elements = output.numel() def quantize_rowwise(x: torch.Tensor):
grid = lambda meta: (x.shape[0],) output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
return output, output_maxs
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment