Commit 5f3d9ada authored by Mitchell Wortsman's avatar Mitchell Wortsman
Browse files

triton-v1

parent 51f8bb71
......@@ -3,3 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
from .triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear
import torch
import torch.nn as nn
import time
from .triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
from .triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
from .triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias
from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias
from .triton_utils.v0.fused_gelu_quantize import quantize_rowwise_nogroup_gelu, quantize_rowwise_nogroup_back_gelu
class _switchback(torch.autograd.Function):
    """SwitchBack linear op with row-wise int8 quantization.

    Forward: both the flattened activations and the weight are quantized
    row-wise to int8, multiplied with a fused dequantize + bias add.
    Backward: grad wrt the input goes through the quantized path (with a
    column-wise transposed weight quantization); grad wrt the weight is
    computed in the gradient's dtype without quantization.
    """

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Flatten all leading dims so the matmul sees a 2-D activation matrix.
        X = X_3D.view(-1, X_3D.size(-1))
        # Fix: use the proper autograd API. The original assigned a tuple over
        # the method (`ctx.save_for_backward = X, W`), which bypasses
        # autograd's saved-tensor bookkeeping and in-place-modification checks.
        ctx.save_for_backward(X, W)
        X_int8, state_X = quantize_rowwise_nogroup(X)
        W_int8, state_W = quantize_rowwise_nogroup(W)
        # int8 matmul with fused row-wise dequantize and bias add.
        return int8_matmul_rowwise_dequantize_bias(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.saved_tensors
        G = G_3D.reshape(-1, G_3D.size(-1))
        grad_X = grad_W = grad_bias = None
        if ctx.needs_input_grad[0]:
            # Quantized path for the (large) gradient wrt activations.
            G_int8, state_G = quantize_rowwise_nogroup(G)
            W_int8, state_W = quantize_columnwise_nogroup_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # Weight gradient stays in the gradient's dtype (no quantization).
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear):
    """Linear layer backed by the row-wise int8 SwitchBack kernels.

    Training quantizes on the fly via ``_switchback``; the first eval-mode
    call freezes an int8 copy of the weight and deletes the floating-point
    parameter.
    """

    def prepare_for_eval(self):
        # Per-output-row absmax scales, then symmetric int8 quantization.
        scales = self.weight.abs().max(dim=1, keepdim=True)[0]
        quantized = (127 * self.weight.float() / scales).round().to(torch.int8)
        scales = scales.squeeze()
        self.register_buffer("W_int8", quantized)
        self.register_buffer("state_W", scales)
        # The fp weight is no longer needed once the int8 copy exists.
        del self.weight

    def forward(self, x):
        if self.training:
            return _switchback.apply(x, self.weight, self.bias)
        # Lazily freeze the int8 weight on the first eval call.
        if not hasattr(self, "state_W"):
            self.prepare_for_eval()
        flat = x.view(-1, x.size(-1))
        flat_int8, state_x = quantize_rowwise_nogroup(flat)
        return int8_matmul_rowwise_dequantize_bias(
            flat_int8, self.W_int8.t(), state_x, self.state_W, self.bias
        ).view(*x.size()[:-1], -1)
class _switchback_global(torch.autograd.Function):
    """SwitchBack linear op with row-wise activation quantization and a single
    global (per-tensor) scale for the weight.

    Backward mirrors the forward: quantized path for the input gradient,
    plain matmul for the weight gradient.
    """

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        X = X_3D.view(-1, X_3D.size(-1))
        X_int8, state_X = quantize_rowwise_nogroup(X)
        W_int8, state_W = quantize_global(W)
        # Fix: use the proper autograd API. The original assigned a tuple over
        # the `save_for_backward` method, bypassing autograd's saved-tensor
        # bookkeeping and in-place checks.
        ctx.save_for_backward(X, W)
        return int8_matmul_mixed_dequanitze_bias(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        G = G_3D.reshape(-1, G_3D.size(-1))
        grad_X = grad_W = grad_bias = None
        X, W = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # Quantized path: row-wise grad scales, global transposed weight scale.
            G_int8, state_G = quantize_rowwise_nogroup(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        return grad_X, grad_W, grad_bias
class SwitchBackGlobalLinear(nn.Linear):
    """Linear layer backed by the global-scale SwitchBack kernels.

    Training quantizes on the fly via ``_switchback_global``; the first
    eval-mode call freezes an int8 weight copy with a single per-tensor
    scale and deletes the floating-point parameter.
    """

    def prepare_for_eval(self):
        # Single per-tensor absmax scale, symmetric int8 quantization.
        scale = self.weight.abs().max()
        quantized = (127 * self.weight.float() / scale).round().to(torch.int8)
        self.register_buffer("W_int8", quantized)
        self.register_buffer("state_W", scale)
        # Drop the fp weight once the frozen int8 copy exists.
        del self.weight

    def forward(self, x):
        if self.training:
            return _switchback_global.apply(x, self.weight, self.bias)
        # Lazily freeze the int8 weight on the first eval call.
        if not hasattr(self, "state_W"):
            self.prepare_for_eval()
        flat = x.view(-1, x.size(-1))
        flat_int8, state_x = quantize_rowwise_nogroup(flat)
        return int8_matmul_mixed_dequanitze_bias(
            flat_int8, self.W_int8.t(), state_x, self.state_W, self.bias
        ).view(*x.size()[:-1], -1)
class LinearFunction(torch.autograd.Function):
    """Reference (non-quantized) linear autograd op: y = x @ W.T + b.

    Saves the flattened 2-D activations so the backward weight gradient is a
    single 2-D matmul regardless of the input's leading dims.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        flat = input.view(-1, input.size(-1))
        ctx.save_for_backward(flat, weight, bias)
        out = input.matmul(weight.t())
        if bias is not None:
            out = out + bias.unsqueeze(0).expand_as(out)
        return out.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output_3D):
        flat_in, weight, bias = ctx.saved_tensors
        g = grad_output_3D.reshape(-1, grad_output_3D.size(-1))
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # dL/dx = dL/dy @ W, reshaped back to the caller's leading dims.
            grad_input = g.matmul(weight.to(g.dtype)).view(*grad_output_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            # dL/dW = dL/dy^T @ x  (both flattened to 2-D).
            grad_weight = g.t().matmul(flat_in.to(g.dtype))
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = g.sum(0)
        return grad_input, grad_weight, grad_bias
class MyLinear(nn.Linear):
    """nn.Linear whose forward is routed through the custom LinearFunction op."""

    def forward(self, x):
        result = LinearFunction.apply(x, self.weight, self.bias)
        return result
class _switchback_mlp(torch.autograd.Function):
    """Fused two-layer MLP: int8 linear -> fused GELU + quantize -> int8 linear.

    Activations are quantized row-wise; weights use a single global
    (per-tensor) scale. The GELU forward and backward are fused into the
    quantization kernels so the fp activations never round-trip separately.
    """
    @staticmethod
    def forward(ctx, X_3D, W1, B1, W2, B2):
        # Flatten leading dims to a 2-D activation matrix.
        X1 = X_3D.view(-1, X_3D.size(-1))
        X1_int8, state_X1 = quantize_rowwise_nogroup(X1)
        W1_int8, state_W1 = quantize_global(W1)
        # First linear layer (pre-GELU activations).
        X2_pre = int8_matmul_mixed_dequanitze_bias(
            X1_int8, W1_int8.t(), state_X1, state_W1, B1
        )
        # X2_v1 = torch.nn.functional.gelu(X2)
        # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1)
        # Fused GELU + row-wise int8 quantize; X2 is the fp16 GELU output.
        X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre)
        W2_int8, state_W2 = quantize_global(W2)
        # Second linear layer.
        out = int8_matmul_mixed_dequanitze_bias(
            X2_int8, W2_int8.t(), state_X2, state_W2, B2
        )
        # NOTE(review): this assigns a tuple OVER the save_for_backward method
        # instead of calling it, bypassing autograd's saved-tensor checks.
        # The proper API is ctx.save_for_backward(...) / ctx.saved_tensors.
        ctx.save_for_backward = X1, W1, X2, X2_pre, W2
        return out.view(*X_3D.size()[:-1], -1)
    @staticmethod
    def backward(ctx, G_3D):
        G2 = G_3D.reshape(-1, G_3D.size(-1))
        grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None
        X1, W1, X2, X2_pre, W2 = ctx.save_for_backward
        # Gradient wrt the second layer's input — computed unconditionally
        # because the first layer's gradients below all depend on it.
        G2_int8, state_G2 = quantize_rowwise_nogroup(G2)
        W2_int8, state_W2 = quantize_global_transpose(W2)
        # NOTE(review): G1 is reshaped to the caller's leading dims here, but
        # quantize_rowwise_nogroup_back_gelu below indexes x.shape[0]/[1] as
        # rows/cols — verify this path is only used with 2-D X_3D.
        G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view(
            *G_3D.size()[:-1], -1
        )
        # Second-layer weight/bias grads in the gradient's dtype.
        grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype))
        grad_B2 = G2.sum(dim=0)
        # Fused GELU backward + row-wise quantization of the propagated grad.
        G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre)
        if ctx.needs_input_grad[0]:
            W1_int8, state_W1 = quantize_global_transpose(W1)
            grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype))
        if ctx.needs_input_grad[2]:
            grad_B1 = G1.sum(dim=0)
        return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2
class SwitchBackGlobalMLP(nn.Module):
    """Two-layer MLP (linear -> GELU -> linear) that executes through the
    fused ``_switchback_mlp`` autograd function."""

    def __init__(self, dim_in, dim_hidden):
        super().__init__()
        # Plain nn.Linear modules hold the parameters; the fused op uses
        # their weights/biases directly.
        self.linear1 = nn.Linear(dim_in, dim_hidden)
        self.linear2 = nn.Linear(dim_hidden, dim_in)

    def forward(self, x):
        return _switchback_mlp.apply(
            x,
            self.linear1.weight,
            self.linear1.bias,
            self.linear2.weight,
            self.linear2.bias,
        )
\ No newline at end of file
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
tl.libdevice  # NOTE(review): bare attribute access; presumably fails fast if libdevice is unavailable — confirm intent
# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=1, num_warps=8),
        triton.Config({}, num_stages=2, num_warps=8),
        triton.Config({}, num_stages=4, num_warps=8),
        triton.Config({}, num_stages=8, num_warps=8),
        triton.Config({}, num_stages=1),
        triton.Config({}, num_stages=2),
        triton.Config({}, num_stages=4),
        triton.Config({}, num_stages=8),
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup_gelu(
    x_ptr,
    output_ptr,
    output_maxs,
    output_fp16,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # One program per row. BLOCK_SIZE is the true row width; P2 is the next
    # power of two (tl.arange requires a power-of-two extent), with row_mask
    # disabling the padding lanes.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)
    # tanh-approximated GELU: 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
    # 0.7978845608 = sqrt(2/pi); x * (1 + 0.044715 x^2) == x + 0.044715 x^3.
    cdf = 0.5 * (1.0 + tl.libdevice.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
    x_new = x * cdf
    # Store the unquantized fp GELU output alongside the int8 values.
    tl.store(output_fp16 + offsets, x_new, mask=row_mask)
    # Symmetric int8 quantization against the row absmax (llrint = round to
    # nearest integer).
    abs_x = tl.abs(x_new)
    max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x_new / max_val))
    tl.store(output_ptr + offsets, output, mask=row_mask)
    tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup_gelu(x: torch.Tensor):
    """Apply tanh-GELU to the 2-D tensor ``x`` and quantize each row to int8.

    Returns (int8 tensor, per-row fp16 absmax scales, fp16 GELU output).
    """
    rows, cols = x.shape[0], x.shape[1]
    out_int8 = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    out_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
    out_maxs = torch.empty(rows, device=x.device, dtype=torch.float16)
    # The kernel's arange must be a power of two covering a full row.
    P2 = int(2 ** (math.ceil(math.log2(cols))))
    assert x.is_cuda and out_int8.is_cuda
    n_elements = out_int8.numel()
    grid = lambda meta: (rows,)
    _quantize_rowwise_nogroup_gelu[grid](x, out_int8, out_maxs, out_fp16, n_elements, BLOCK_SIZE=cols, P2=P2)
    return out_int8, out_maxs, out_fp16
# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=1, num_warps=8),
        triton.Config({}, num_stages=2, num_warps=8),
        triton.Config({}, num_stages=4, num_warps=8),
        triton.Config({}, num_stages=8, num_warps=8),
        triton.Config({}, num_stages=1),
        triton.Config({}, num_stages=2),
        triton.Config({}, num_stages=4),
        triton.Config({}, num_stages=8),
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup_back_gelu(
    x_ptr,
    in_ptr,
    output_ptr,
    output_maxs,
    output_fp16,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # One program per row: fuse the tanh-GELU backward with row-wise int8
    # quantization. x_ptr holds the incoming gradient, in_ptr the pre-GELU
    # activations. P2 is the next power of two >= BLOCK_SIZE (row width).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x_out = tl.load(x_ptr + offsets, mask=row_mask)
    x_in = tl.load(in_ptr + offsets, mask=row_mask)
    # Forward cdf term: 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))).
    cdf = 0.5 * (1.0 + tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in)))
    intermediate = tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in))
    # d(cdf)/dx: 0.1070322243 = sqrt(2/pi) * 3 * 0.044715, and
    # (1 - tanh^2 u) = sech^2 u.
    dcdf = 0.5 * (0.7978845608 + 0.1070322243 * x_in * x_in) * (1 - intermediate * intermediate)
    # Chain rule: dGELU/dx = cdf + x * dcdf, applied to the incoming grad.
    x = x_out * (cdf + x_in * dcdf)
    # Store the unquantized fp gradient alongside the int8 values.
    tl.store(output_fp16 + offsets, x, mask=row_mask)
    # Symmetric int8 quantization against the row absmax.
    abs_x = tl.abs(x)
    max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x / max_val))
    tl.store(output_ptr + offsets, output, mask=row_mask)
    tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup_back_gelu(x: torch.Tensor, y: torch.Tensor):
    """Fused GELU backward + row-wise int8 quantization.

    ``x`` is the incoming gradient, ``y`` the pre-GELU activations.
    Returns (int8 grad, per-row fp16 absmax scales, fp16 grad).
    """
    rows, cols = x.shape[0], x.shape[1]
    out_int8 = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    out_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
    out_maxs = torch.empty(rows, device=x.device, dtype=torch.float16)
    # The kernel's arange must be a power of two covering a full row.
    P2 = int(2 ** (math.ceil(math.log2(cols))))
    assert x.is_cuda and out_int8.is_cuda
    n_elements = out_int8.numel()
    grid = lambda meta: (rows,)
    _quantize_rowwise_nogroup_back_gelu[grid](x, y, out_int8, out_maxs, out_fp16, n_elements, BLOCK_SIZE=cols, P2=P2)
    return out_int8, out_maxs, out_fp16
# if __name__ == '__main__':
# torch.manual_seed(0)
# x = torch.randn(1280, 768).cuda().to(torch.float16)
# out = quantize_rowwise_nogroup(x)
# x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
# max2 = x.abs().max(1)[0]
# print(torch.allclose(out[1], max2))
# print( (x_real == out[0]).float().mean() )
# # for i in range(x.shape[0]):
# # print( (x_real[i, :] == out[0][i, :]).float().mean() )
# # print(out[0])
# # print(x_real)
# # import pdb; pdb.set_trace()
# # print(out[2])
# # print(out[2][:10])
# sums = x.sum(dim=0)
# #print(sums[:10])
# #print( (sums == out[2]).float().mean() )
# import pdb; pdb.set_trace()
# # import pdb; pdb.set_trace()
# # exit()
# # repeat = 16
# # for _ in range(8):
# # out = quantize_rowwise_nogroup(x)
# # triton_graph = torch.cuda.CUDAGraph()
# # with torch.cuda.graph(triton_graph):
# # out = quantize_rowwise_nogroup(x)
# # triton_graph.replay()
# # torch.cuda.synchronize()
# # start = time.time()
# # for _ in range(repeat):
# # triton_graph.replay()
# # torch.cuda.synchronize()
# # end = time.time()
# # print(out[0])
# # print(out[1])
# # print(x / x.abs().max(dim=1, keepdim=True)[0])
# # max1 = out[1]
# # max2 = x.abs().max(1)[0]
# # print(max1, max2)
# # print(torch.allclose(max1, max2))
# #print(f"time: {(end - start) / repeat * 1000:.3f} ms")
import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
    """Return a kernel pre-hook that zeroes the tensor bound to ``name``."""
    def _zero(nargs):
        return nargs[name].zero_()
    return _zero
def get_configs_io_bound():
    """Enumerate small-tile (IO-bound) matmul configs, including split-K
    variants that pre-zero the output for atomic accumulation."""
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    tile = {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k}
                    configs.append(
                        triton.Config(dict(tile, SPLIT_K=1),
                                      num_stages=num_stages, num_warps=num_warps))
                    # split_k: output must be zeroed before atomics accumulate
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config(dict(tile, SPLIT_K=split_k),
                                          num_stages=num_stages, num_warps=num_warps,
                                          pre_hook=init_to_zero('C')))
    return configs
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    # EVEN_K lets the inner loop skip bounds masking along K.
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
            ACC_TYPE: tl.constexpr
            ):
    # int8 x int8 matmul with dequantization fused into the epilogue:
    #   C = state_w * state_x[row] * (A @ B) * divfactor, cast to C's dtype.
    # state_x_ptr holds per-row activation scales; state_w_ptr a single
    # (scalar) global weight scale.
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # scalar global weight scale; per-row activation scales for this tile
    w_factor = tl.load(state_w_ptr)
    x_factor = tl.load(state_x_ptr + ram)[:, None]
    # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # NOTE(review): accumulator is hard-coded to int32, so ACC_TYPE is unused.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    # fused dequantize: rescale the int32 accumulator to real values
    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w):
    """int8 @ int8 -> fp16 matmul with fused dequantize: row-wise scales for
    ``a`` (activations), one global scale for ``b`` (weight).

    (The misspelled name is kept for API compatibility.)
    """
    device = a.device
    # both operands were quantized symmetrically against 127
    divfactor = 1. / (127. * 127.)
    # make fully-strided inputs contiguous
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # fp16 output buffer
    out = torch.empty((M, N), device=device, dtype=torch.float16)
    ACC_TYPE = tl.float32  # currently unused: the kernel accumulates in int32
    # 2-D launch: output tiles on axis 0, split-K on axis 1
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _kernel[grid](a, b, out, state_x, state_w, M, N, K, divfactor,
                  a.stride(0), a.stride(1),
                  b.stride(0), b.stride(1),
                  out.stride(0), out.stride(1),
                  GROUP_M=8, ACC_TYPE=ACC_TYPE)
    return out
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    # EVEN_K lets the inner loop skip bounds masking along K.
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
                 stride_am, stride_ak,
                 stride_bk, stride_bn,
                 stride_cm, stride_cn,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                 GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                 ACC_TYPE: tl.constexpr
                 ):
    # Same as _kernel (int8 x int8 with fused dequantize using per-row
    # activation scales and a scalar global weight scale), plus an optional
    # bias add in the epilogue when has_bias is set.
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # scalar global weight scale; per-row activation scales for this tile
    w_factor = tl.load(state_w_ptr)
    x_factor = tl.load(state_x_ptr + ram)[:, None]
    # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # NOTE(review): accumulator is hard-coded to int32, so ACC_TYPE is unused.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    # fused dequantize: rescale the int32 accumulator to real values
    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)
    # optional per-column bias add after the cast to the output dtype
    if has_bias:
        bias = tl.load(bias + rn).to(C.dtype.element_ty)
        acc = acc + bias[None, :]
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze_bias(a, b, state_x, state_w, bias):
    """int8 @ int8 -> fp16 matmul with fused dequantize (row-wise activation
    scales, one global weight scale) and an optional bias add.

    ``bias`` may be None. (The misspelled name is kept for API compatibility.)
    """
    device = a.device
    # both operands were quantized symmetrically against 127
    divfactor = 1. / (127. * 127.)
    has_bias = 0 if bias is None else 1
    # make fully-strided inputs contiguous
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # fp16 output buffer
    out = torch.empty((M, N), device=device, dtype=torch.float16)
    ACC_TYPE = tl.float32  # currently unused: the kernel accumulates in int32
    # 2-D launch: output tiles on axis 0, split-K on axis 1
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _kernel_bias[grid](a, b, out, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                       a.stride(0), a.stride(1),
                       b.stride(0), b.stride(1),
                       out.stride(0), out.stride(1),
                       GROUP_M=8, ACC_TYPE=ACC_TYPE)
    return out
import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
    """Return a kernel pre-hook that zeroes the tensor bound to ``name``."""
    def _zero(nargs):
        return nargs[name].zero_()
    return _zero
def get_configs_io_bound():
    """Enumerate small-tile (IO-bound) matmul configs, including split-K
    variants that pre-zero the output for atomic accumulation."""
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    tile = {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k}
                    configs.append(
                        triton.Config(dict(tile, SPLIT_K=1),
                                      num_stages=num_stages, num_warps=num_warps))
                    # split_k: output must be zeroed before atomics accumulate
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config(dict(tile, SPLIT_K=split_k),
                                          num_stages=num_stages, num_warps=num_warps,
                                          pre_hook=init_to_zero('C')))
    return configs
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    # EVEN_K lets the inner loop skip bounds masking along K.
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
            ACC_TYPE: tl.constexpr
            ):
    # int8 x int8 matmul with dequantization fused into the epilogue:
    #   C = state_w[col] * state_x[row] * (A @ B) * divfactor.
    # Unlike the "mixed" variant, the weight scale here is per output column
    # (row-wise quantization of the original weight), and divfactor is a
    # runtime argument rather than a constexpr.
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # per-column weight scales and per-row activation scales for this tile
    w_factor = tl.load(state_w_ptr + rbn)[None, :]
    x_factor = tl.load(state_x_ptr + ram)[:, None]
    # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # NOTE(review): accumulator is hard-coded to int32, so ACC_TYPE is unused.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    # fused dequantize: rescale the int32 accumulator to real values
    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w):
    """int8 @ int8 -> fp16 matmul with fused dequantize using row-wise scales
    for ``a`` and per-column scales (row-wise of the untransposed weight)
    for ``b``."""
    # both operands were quantized symmetrically against 127
    divfactor = 1. / (127. * 127.)
    device = a.device
    # make fully-strided inputs contiguous
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # fp16 output buffer
    out = torch.empty((M, N), device=device, dtype=torch.float16)
    ACC_TYPE = tl.float32  # currently unused: the kernel accumulates in int32
    # 2-D launch: output tiles on axis 0, split-K on axis 1
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _kernel[grid](a, b, out, state_x, state_w, M, N, K, divfactor,
                  a.stride(0), a.stride(1),
                  b.stride(0), b.stride(1),
                  out.stride(0), out.stride(1),
                  GROUP_M=8, ACC_TYPE=ACC_TYPE)
    return out
import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
    """Return a kernel pre-hook that zeroes the tensor bound to ``name``."""
    def _zero(nargs):
        return nargs[name].zero_()
    return _zero
def get_configs_io_bound():
    """Enumerate small-tile (IO-bound) matmul configs, including split-K
    variants that pre-zero the output for atomic accumulation."""
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    tile = {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k}
                    configs.append(
                        triton.Config(dict(tile, SPLIT_K=1),
                                      num_stages=num_stages, num_warps=num_warps))
                    # split_k: output must be zeroed before atomics accumulate
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config(dict(tile, SPLIT_K=split_k),
                                          num_stages=num_stages, num_warps=num_warps,
                                          pre_hook=init_to_zero('C')))
    return configs
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    # true when K divides evenly into the per-program K tiles
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
            ACC_TYPE: tl.constexpr
            ):
    # int8 x int8 -> fp16 matmul with fused rowwise dequantization and an
    # optional bias add.  A is (M, K) int8, B is (K, N) int8; state_x_ptr /
    # state_w_ptr hold per-row / per-column absmax scales, and divfactor is
    # 1/(127*127), so acc * x_factor * w_factor * divfactor recovers the
    # real-valued product.
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)  # split-K slice index
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # per-row (x) and per-column (w) dequantization scales for this tile
    w_factor = tl.load(state_w_ptr + rbn)[None, :]
    x_factor = tl.load(state_x_ptr + ram)[:, None]
    # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # int32 accumulator for the int8 dot products (ACC_TYPE is unused here)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # mask off the ragged tail of the K dimension
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    # dequantize the int32 accumulator back to real values
    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)
    if has_bias:
        bias = tl.load(bias + rn).to(C.dtype.element_ty)
        acc = acc + bias[None, :]
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # split-K: partial sums from each K slice are combined atomically
        tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize_bias(a, b, state_x, state_w, bias):
    """Int8 matmul with fused rowwise dequantization and optional bias.

    a: int8 (M, K); b: int8 (K, N); state_x / state_w are the per-row /
    per-column absmax scales; bias may be None.  Returns fp16 (M, N).
    """
    # combined inverse scale of the two 127-step quantization grids
    divfactor = 1. / (127. * 127.)

    if bias is None:
        has_bias = 0
    else:
        has_bias = 1
        bias = bias.contiguous()

    device = a.device
    # make inputs contiguous only when neither stride is unit
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # output is dequantized fp16
    c = torch.empty((M, N), device=device, dtype=torch.float16)
    ACC_TYPE = tl.float32

    # 2D launch: (M/N tile programs, SPLIT_K slices)
    def grid(META):
        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])

    _kernel[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                  a.stride(0), a.stride(1),
                  b.stride(0), b.stride(1),
                  c.stride(0), c.stride(1),
                  GROUP_M=8, ACC_TYPE=ACC_TYPE)
    return c
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=1),
        triton.Config({}, num_stages=2),
        triton.Config({}, num_stages=4),
        triton.Config({}, num_stages=8),
        triton.Config({}, num_stages=16),
        triton.Config({}, num_stages=1, num_warps=8),
        triton.Config({}, num_stages=2, num_warps=8),
        triton.Config({}, num_stages=4, num_warps=8),
        triton.Config({}, num_stages=8, num_warps=8),
        triton.Config({}, num_stages=16, num_warps=8),
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements']
)
@triton.jit
def _quantize_columnwise_nogroup_transpose(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    M : tl.constexpr, N : tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # One program per column of the row-major (M, N) input: quantize that
    # column to int8 using the column's absmax and write it as a contiguous
    # row of the (N, M) output — quantization fused with a transpose.
    # P2 is M rounded up to a power of two, as required by tl.arange.
    pid = tl.program_id(axis=0)
    block_start = pid  # column index
    p2_arange = tl.arange(0, P2)
    p2_arange_mask = p2_arange < M
    # stride N walks down one column of the row-major input
    arange = p2_arange * N
    offsets = block_start + arange
    x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
    abs_x = tl.abs(x)
    # absmax over the column; masked lanes contribute 0
    max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x / max_val))
    # write the column contiguously as row `pid` of the transposed output
    new_start = pid * M
    new_offsets = new_start + p2_arange
    tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
    tl.store(output_maxs + pid, max_val)
def quantize_columnwise_nogroup_transpose(x: torch.Tensor):
    """Columnwise int8 quantization fused with a transpose.

    Given x of shape (M, N), returns (x_q, maxs): x_q is the int8
    quantized transpose of shape (N, M), and maxs holds the per-column
    absmax scales (fp16, length N).
    """
    M, N = x.shape
    quantized = torch.empty(N, M, device=x.device, dtype=torch.int8)
    col_maxs = torch.empty(N, device=x.device, dtype=torch.float16)
    # smallest power of two >= M, required by tl.arange in the kernel
    P2 = int(2 ** (math.ceil(math.log2(M))))

    assert x.is_cuda and quantized.is_cuda
    n_elements = quantized.numel()
    # BLOCK_SIZE == M makes the grid launch exactly one program per column
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_columnwise_nogroup_transpose[grid](x, quantized, col_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
    return quantized, col_maxs
if __name__ == '__main__':
    # Correctness check of quantize_columnwise_nogroup_transpose against a
    # PyTorch reference (transpose, then rowwise quantize in fp32).
    torch.manual_seed(0)
    x = torch.randn(1280, 768).cuda().to(torch.float16)
    out = quantize_columnwise_nogroup_transpose(x)
    x_real = x.t().float()
    x_real_int8 = (127. * x_real / x_real.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
    maxs = x_real.abs().max(dim=1, keepdim=True)[0].half()
    #print(out[0][2,:])
    # fraction of int8 values / scales that match the reference exactly
    # (rounding in fp16 vs fp32 may differ on a small fraction of elements)
    print((out[0] == x_real_int8).float().mean())
    print((out[1] == maxs[:, 0]).float().mean())

    # print(out[0])
    # print(out[1])
    # print(out[0][2,:])
    # print(x_real[2, :])
    # print((out[0] != x_real).nonzero())
    #import pdb; pdb.set_trace()
    # repeat = 16
    # for _ in range(8):
    #     out = quantize_columnwise_nogroup_transpose(x)
    # triton_graph = torch.cuda.CUDAGraph()
    # with torch.cuda.graph(triton_graph):
    #     out = quantize_columnwise_nogroup_transpose(x)
    # triton_graph.replay()
    # torch.cuda.synchronize()
    # start = time.time()
    # for _ in range(repeat):
    #     triton_graph.replay()
    # torch.cuda.synchronize()
    # end = time.time()
    # print(out[0])
    # print(out[1])
    # print(x / x.abs().max(dim=0, keepdim=True)[0])
    # x_real = (127 * (x / x.abs().max(dim=0, keepdim=True)[0])).round().to(torch.int8)
    # max1 = out[1]
    # max2 = x.abs().max(0)[0]
    # print(max1, max2)
    # import pdb; pdb.set_trace()
    # print(torch.allclose(max1, max2))
    # print(f"time: {(end - start) / repeat * 1000:.3f} ms")
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
    ],
    key=['n_elements']
)
@triton.jit
def _quantize_global(
    x_ptr,
    absmax_inv_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise int8 quantization with a single tensor-wide scale: each
    # program handles one BLOCK_SIZE chunk and stores round(127 * x * absmax_inv).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # scalar 1/absmax shared by every element
    absmax_inv = tl.load(absmax_inv_ptr)
    output = tl.libdevice.llrint(127. * (x * absmax_inv))
    tl.store(output_ptr + offsets, output, mask=mask)
def quantize_global(x: torch.Tensor):
    """Quantize `x` to int8 with a single tensor-wide absmax scale.

    Returns (x_int8, absmax): absmax is a 1-element tensor holding the
    scale, so dequantization is x_int8 * absmax / 127.
    """
    absmax = x.abs().max().unsqueeze(0)
    absmax_inv = 1./ absmax
    # bug fix: allocate on x's own device instead of the hard-coded 'cuda'
    # (i.e. cuda:0), which put the output on the wrong GPU for inputs
    # living on any other device
    output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_global[grid](x, absmax_inv, output, n_elements)
    return output, absmax
@triton.autotune(
    configs=[
        # fix: the original listed this identical config twice, making
        # autotune benchmark the same candidate two times; deduplicated.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
        # ...
    ],
    key=['M', 'N']
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
                               BLOCK_M : tl.constexpr,
                               BLOCK_N : tl.constexpr,
                               GROUP_M : tl.constexpr):
    # Fused transpose + global int8 quantization: load a (BLOCK_M, BLOCK_N)
    # tile of A, scale by the shared 1/absmax, and store it into B with
    # swapped strides so B ends up as the quantized transpose.
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # group programs along M for better L2 reuse (same scheme as the matmuls)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    a = tl.load(A, mask=mask)
    absmax_inv = tl.load(absmax_inv_ptr)

    # rematerialize to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # swapped strides (stride_bm with rm, stride_bn with rn) write transposed
    B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]

    output = tl.libdevice.llrint(127. * (a * absmax_inv))

    tl.store(B, output, mask=mask)
def quantize_global_transpose(input):
    """Quantize `input` to int8 with a global scale and return it transposed.

    Returns (out, absmax): out has shape (N, M) for an (M, N) input, and
    absmax is the 1-element tensor-wide scale.
    """
    absmax = input.abs().max().unsqueeze(0)
    absmax_inv = 1./ absmax
    M, N = input.shape
    # bug fix: allocate on input's own device instead of the hard-coded
    # 'cuda' (cuda:0), which broke inputs living on any other GPU
    out = torch.empty(N, M, device=input.device, dtype=torch.int8)

    assert out.size(0) == N and out.size(1) == M
    # the kernel addresses both tensors through strides but requires one
    # unit stride on each (row- or column-major, no general strides)
    assert input.stride(0) == 1 or input.stride(1) == 1
    assert out.stride(0) == 1 or out.stride(1) == 1

    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
    return out, absmax
if __name__ == '__main__':
    # Correctness check of quantize_global against a PyTorch reference.
    w = torch.randn(768, 1280).cuda().to(torch.float16)

    W_int8, state_w = quantize_global(w)
    r_state_w = w.abs().max()
    # reference global int8 quantization computed in fp32
    r_W_int8 = ((127 * w.float()) / state_w).round().to(torch.int8)
    # fraction of elements matching the reference exactly
    print((r_W_int8 == W_int8).float().mean())
    # print(r_W_int8)
    # print(W_int8)

    # NOTE(review): everything below is dead code — this early exit() skips
    # the CUDA-graph timing benchmark that follows.
    exit()

    repeat = 16

    for _ in range(8):
        out = quantize_global(w)

    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph):
        out = quantize_global(w)

    triton_graph.replay()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        triton_graph.replay()

    torch.cuda.synchronize()
    end = time.time()

    print(f"time: {(end - start) / repeat * 1000:.3f} ms")
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=1, num_warps=8),
        triton.Config({}, num_stages=2, num_warps=8),
        triton.Config({}, num_stages=4, num_warps=8),
        triton.Config({}, num_stages=8, num_warps=8),
        triton.Config({}, num_stages=1),
        triton.Config({}, num_stages=2),
        triton.Config({}, num_stages=4),
        triton.Config({}, num_stages=8),
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # One program per row: quantize the row to int8 using the row's own
    # absmax.  BLOCK_SIZE is the row length; P2 is the row length rounded
    # up to a power of two, as required by tl.arange.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)
    abs_x = tl.abs(x)
    # absmax over the row; masked lanes contribute 0
    max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x / max_val))
    tl.store(output_ptr + offsets, output, mask=row_mask)
    tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup(x: torch.Tensor):
    """Rowwise int8 quantization.

    Returns (x_int8, row_maxs): x_int8 has x's shape and row_maxs holds
    the per-row absmax scales (fp16, length M).
    """
    rows, cols = x.shape
    quantized = torch.empty(rows, cols, device=x.device, dtype=torch.int8)
    row_maxs = torch.empty(rows, device=x.device, dtype=torch.float16)
    # smallest power of two >= row length, required by tl.arange
    P2 = int(2 ** (math.ceil(math.log2(cols))))

    assert x.is_cuda and quantized.is_cuda
    n_elements = quantized.numel()
    # one program per row
    grid = lambda meta: (rows,)
    _quantize_rowwise_nogroup[grid](x, quantized, row_maxs, n_elements, BLOCK_SIZE=cols, P2=P2)
    return quantized, row_maxs
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements']
)
@triton.jit
def _experimental_quantize_rowwise_nogroup(
    x_ptr,
    output_ptr,
    bias_grad_ptr,
    output_maxs,
    n_elements,
    M: tl.constexpr, N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
    P2M: tl.constexpr,
):
    # Fused kernel launched with M + N programs over an (M, N) input:
    # the first M programs do rowwise int8 quantization (same math as
    # _quantize_rowwise_nogroup); the remaining N programs each sum one
    # column into bias_grad_ptr.  P2 / P2M are N / M rounded up to powers
    # of two, as required by tl.arange.
    pid = tl.program_id(axis=0)
    if pid < M:
        # rowwise-quantization branch: one program per row
        block_start = pid * BLOCK_SIZE
        arange = tl.arange(0, P2)
        offsets = block_start + arange
        row_mask = arange < BLOCK_SIZE
        x = tl.load(x_ptr + offsets, mask=row_mask)
        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
        output = tl.libdevice.llrint(127. * (x / max_val))
        tl.store(output_ptr + offsets, output, mask=row_mask)
        tl.store(output_maxs + pid, max_val)
    else:
        # column-sum branch: program (pid - M) reduces column (pid - M)
        real_pid = pid - M
        arange_new = tl.arange(0, P2M)
        mask_new = arange_new < M
        # stride N walks down one column of the row-major input
        offsets_new = real_pid + arange_new * N
        new_x = tl.load(x_ptr + offsets_new, mask=mask_new)
        # accumulate the column sum in fp32 for accuracy
        s = tl.sum(tl.where(mask_new, new_x, 0).to(tl.float32), axis=0)
        tl.store(bias_grad_ptr + real_pid, s)
def experimental_quantize_rowwise_nogroup(x: torch.Tensor):
    """Rowwise int8 quantization fused with a column sum.

    Returns (x_int8, row_maxs, bias_grad): bias_grad[j] is the sum of
    column j, computed by the extra N programs of the fused kernel.
    """
    M, N = x.shape
    quantized = torch.empty(M, N, device=x.device, dtype=torch.int8)
    row_maxs = torch.empty(M, device=x.device, dtype=torch.float16)
    bias_grad = torch.empty(N, device=x.device, dtype=torch.float16)
    # power-of-two paddings for tl.arange over columns (P2) and rows (P2M)
    P2 = int(2 ** (math.ceil(math.log2(N))))
    P2M = int(2 ** (math.ceil(math.log2(M))))

    assert x.is_cuda and quantized.is_cuda
    n_elements = quantized.numel()
    # M row-quantization programs followed by N column-sum programs
    grid = lambda meta: (M + N,)
    _experimental_quantize_rowwise_nogroup[grid](x, quantized, bias_grad, row_maxs, n_elements, M, N, BLOCK_SIZE=N, P2=P2, P2M=P2M)
    return quantized, row_maxs, bias_grad
if __name__ == '__main__':
    # Correctness check of quantize_rowwise_nogroup against a PyTorch
    # reference.  Bug fix: removed a leftover `import pdb; pdb.set_trace()`
    # breakpoint that dropped into the debugger on every run of this script.
    torch.manual_seed(0)
    x = torch.randn(1280, 768).cuda().to(torch.float16)
    out = quantize_rowwise_nogroup(x)
    # reference rowwise quantization computed in fp32
    x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
    max2 = x.abs().max(1)[0]
    # scales should match exactly; int8 values on the vast majority of elements
    print(torch.allclose(out[1], max2))
    print( (x_real == out[0]).float().mean() )

    # column sums, for comparison with the experimental fused kernel's
    # bias-gradient output
    sums = x.sum(dim=0)
    #print(sums[:10])
    #print( (sums == out[2]).float().mean() )
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
import time
# class AttentionOld(torch.nn.Module):
# def __init__(
# self,
# dim,
# num_heads=8,
# qkv_bias=True,
# scaled_cosine=False,
# scale_heads=False,
# attn_drop=0.,
# proj_drop=0.,
# linear_module=torch.nn.Linear,
# ):
# super().__init__()
# self.scaled_cosine = scaled_cosine
# self.scale_heads = scale_heads
# assert dim % num_heads == 0, 'dim should be divisible by num_heads'
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim ** -0.5
# self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias)
# self.attn_drop = torch.nn.Dropout(attn_drop)
# if self.scale_heads:
# self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1)))
# else:
# self.head_scale = None
# self.out_proj = linear_module(dim, dim)
# self.out_drop = torch.nn.Dropout(proj_drop)
# def forward(self, x, attn_mask = None):
# L, N, C = x.shape
# q, k, v = self.in_proj_linear(x).chunk(3, dim=-1)
# q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# q = q * self.scale
# attn = torch.bmm(q, k.transpose(-1, -2))
# if attn_mask is not None:
# if attn_mask.dtype == torch.bool:
# new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
# new_attn_mask.masked_fill_(attn_mask, float("-inf"))
# attn_mask = new_attn_mask
# attn += attn_mask
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)
# x = torch.bmm(attn, v)
# x = x.transpose(0, 1).reshape(L, N, C)
# x = self.out_proj(x)
# x = self.out_drop(x)
# return x
class Attention(torch.nn.Module):
    """Pre-LN attention block used for benchmarking different linear layers.

    NOTE(review): `num_heads`, `head_dim`, and `scale` are stored but
    `forward` hands the un-split (..., dim) q/k/v straight to
    scaled_dot_product_attention, so attention is effectively single-head
    here — presumably acceptable for a throughput benchmark; confirm
    before reusing this module for training.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            scaled_cosine=False,
            scale_heads=False,
            attn_drop=0.,
            proj_drop=0.,
            linear_module=torch.nn.Linear,
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # pre-attention LayerNorm, then a fused qkv projection
        self.ln = torch.nn.LayerNorm(dim)
        self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = linear_module(dim, dim)
        self.out_drop = torch.nn.Dropout(proj_drop)

    def forward(self, x, attn_mask = None):
        q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1)
        # bug fix: the original wrapped the *result tensor* in
        # torch.compile(...) — torch.compile expects a callable, not a
        # tensor, so the wrapped value broke the subsequent out_proj call.
        # Call scaled_dot_product_attention directly instead.
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
        x = self.out_proj(x)
        return x
def _bench(label, fn, repeat, info):
    """Warm up `fn` for repeat//2 iterations, then time `repeat` iterations.

    Prints and records the mean latency in milliseconds under `label`.
    """
    for _ in range(repeat // 2):
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    end = time.time()
    ms = (end - start) / repeat * 1000
    print(f"time {label}: {ms:.3f} ms")
    info[label] = ms


if __name__ == '__main__':
    # Benchmark attention blocks built from different linear layers
    # (standard nn.Linear, MyLinear, SwitchBackGlobalLinear) plus isolated
    # attention / layernorm / gelu baselines, appending one JSON record per
    # (dim, batch) configuration.  Refactored: the eight copy-pasted
    # warmup/time/record sections now share the _bench helper; the
    # measured work and the x1.grad.zero_() ordering are unchanged.
    for dim in [1024, 1280, 1408, 1664, 2048]:
        for batch in [2**14, 2**15, 2**16, 2**17]:
            x1 = torch.randn(batch // 256, 256, dim).cuda().requires_grad_(True)
            qu = torch.randn(batch // 256, 256, dim).cuda().requires_grad_(True)
            ke = torch.randn(batch // 256, 256, dim).cuda().requires_grad_(True)
            va = torch.randn(batch // 256, 256, dim).cuda().requires_grad_(True)

            standard = Attention(dim).cuda()
            my_standard = Attention(dim, linear_module=MyLinear).cuda()
            sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
            standard_compiled = torch.compile(standard)

            ln_model = torch.nn.Sequential(
                torch.nn.LayerNorm(dim),
                torch.nn.LayerNorm(dim),
            ).cuda()
            ln_model_compiled = torch.compile(ln_model)

            gelu_model = torch.nn.Sequential(
                torch.nn.GELU(),
            ).cuda()
            gelu_model_compiled = torch.compile(gelu_model)

            print('Model part 2')

            repeat = 32
            info = {'repeat': repeat, 'batch_size': batch, 'dim': dim}

            # each step runs forward under autocast and backprops a scaled
            # scalar so the backward pass is included in the timing
            def make_step(model):
                def run():
                    with torch.cuda.amp.autocast():
                        out = model(x1)
                    ((2 ** 16) * out).abs().mean().backward()
                return run

            def run_attn():
                with torch.cuda.amp.autocast():
                    out = torch.nn.functional.scaled_dot_product_attention(qu, ke, va)
                ((2 ** 16) * out).abs().mean().backward()

            _bench('attn', run_attn, repeat, info)
            _bench('ln', make_step(ln_model), repeat, info)
            x1.grad.zero_()
            _bench('ln_compiled', make_step(ln_model_compiled), repeat, info)
            _bench('gelu', make_step(gelu_model), repeat, info)
            x1.grad.zero_()
            _bench('gelu_compiled', make_step(gelu_model_compiled), repeat, info)
            x1.grad.zero_()
            _bench('standard', make_step(standard), repeat, info)
            x1.grad.zero_()
            _bench('my_standard', make_step(my_standard), repeat, info)
            x1.grad.zero_()
            _bench('standard_compiled', make_step(standard_compiled), repeat, info)
            x1.grad.zero_()
            _bench('sb', make_step(sb), repeat, info)

            info_json = json.dumps(info)
            with open("tests/triton_tests/attn_info_ln.jsonl", "a") as file:
                file.write(info_json + "\n")
\ No newline at end of file
{"repeat": 32, "batch_size": 16384, "dim": 1024, "attn": 2.1414458751678467, "ln": 1.6365647315979004, "ln_compiled": 1.799367368221283, "gelu": 1.0930374264717102, "gelu_compiled": 1.094818115234375, "standard": 4.159651696681976, "my_standard": 4.696495831012726, "standard_compiled": 3.675594925880432, "sb": 4.1465312242507935}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "attn": 4.100345075130463, "ln": 3.1594187021255493, "ln_compiled": 3.437422215938568, "gelu": 2.109348773956299, "gelu_compiled": 2.11450457572937, "standard": 7.706902921199799, "my_standard": 8.799396455287933, "standard_compiled": 6.735652685165405, "sb": 7.66376405954361}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "attn": 7.953710854053497, "ln": 6.236426532268524, "ln_compiled": 6.746955215930939, "gelu": 4.164382815361023, "gelu_compiled": 4.171714186668396, "standard": 14.894917607307434, "my_standard": 17.042435705661774, "standard_compiled": 12.985721230506897, "sb": 14.6140456199646}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "attn": 15.638880431652069, "ln": 12.333884835243225, "ln_compiled": 13.272866606712341, "gelu": 8.228793740272522, "gelu_compiled": 8.243747055530548, "standard": 29.425136744976044, "my_standard": 35.08377820253372, "standard_compiled": 25.69487690925598, "sb": 28.760001063346863}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "attn": 2.627238631248474, "ln": 2.0098239183425903, "ln_compiled": 2.4197474122047424, "gelu": 1.3455823063850403, "gelu_compiled": 1.35069340467453, "standard": 5.554787814617157, "my_standard": 6.2290579080581665, "standard_compiled": 5.132324993610382, "sb": 5.4178386926651}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "attn": 5.0596073269844055, "ln": 3.903590142726898, "ln_compiled": 4.719957709312439, "gelu": 2.6203468441963196, "gelu_compiled": 2.627365291118622, "standard": 10.546617209911346, "my_standard": 11.850126087665558, "standard_compiled": 9.685918688774109, "sb": 10.088451206684113}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "attn": 9.845800697803497, "ln": 7.711298763751984, "ln_compiled": 9.292080998420715, "gelu": 5.172915756702423, "gelu_compiled": 5.180932581424713, "standard": 21.371990442276, "my_standard": 23.921720683574677, "standard_compiled": 19.669152796268463, "sb": 20.267993211746216}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "attn": 19.375711679458618, "ln": 15.333592891693115, "ln_compiled": 18.245264887809753, "gelu": 10.264746844768524, "gelu_compiled": 10.283775627613068, "standard": 41.79700464010239, "my_standard": 45.84744572639465, "standard_compiled": 38.35208714008331, "sb": 38.35364431142807}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "attn": 2.9110386967658997, "ln": 2.1998360753059387, "ln_compiled": 2.581551671028137, "gelu": 1.4731436967849731, "gelu_compiled": 1.478634774684906, "standard": 6.764143705368042, "my_standard": 7.331632077693939, "standard_compiled": 6.24605268239975, "sb": 6.325609982013702}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "attn": 5.542516708374023, "ln": 4.289716482162476, "ln_compiled": 5.065307021141052, "gelu": 2.8742849826812744, "gelu_compiled": 2.882353961467743, "standard": 12.749537825584412, "my_standard": 13.79828155040741, "standard_compiled": 11.728867888450623, "sb": 11.642806231975555}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "attn": 10.80312579870224, "ln": 8.471302688121796, "ln_compiled": 9.96796041727066, "gelu": 5.681410431861877, "gelu_compiled": 5.6905597448349, "standard": 25.19702911376953, "my_standard": 27.226239442825317, "standard_compiled": 23.22910726070404, "sb": 22.682294249534607}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "attn": 21.284908056259155, "ln": 16.85701310634613, "ln_compiled": 19.643358886241913, "gelu": 11.292420327663422, "gelu_compiled": 11.314474046230316, "standard": 50.06787180900574, "my_standard": 54.29378151893616, "standard_compiled": 44.58653926849365, "sb": 45.359253883361816}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "attn": 3.382459282875061, "ln": 2.6206374168395996, "ln_compiled": 2.9666870832443237, "gelu": 1.7263293266296387, "gelu_compiled": 1.7317384481430054, "standard": 8.414775133132935, "my_standard": 9.117811918258667, "standard_compiled": 7.7542513608932495, "sb": 7.70898163318634}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "attn": 6.468378007411957, "ln": 5.125559866428375, "ln_compiled": 5.791269242763519, "gelu": 3.3864825963974, "gelu_compiled": 3.3920034766197205, "standard": 16.016244888305664, "my_standard": 17.25083589553833, "standard_compiled": 14.60808515548706, "sb": 14.347739517688751}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "attn": 12.645229697227478, "ln": 10.13532280921936, "ln_compiled": 11.427387595176697, "gelu": 6.6957250237464905, "gelu_compiled": 6.711684167385101, "standard": 31.792201101779938, "my_standard": 34.31189805269241, "standard_compiled": 29.10037338733673, "sb": 28.3128023147583}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "attn": 24.970605969429016, "ln": 20.182937383651733, "ln_compiled": 22.7489173412323, "gelu": 13.326868414878845, "gelu_compiled": 13.345755636692047, "standard": 63.46555054187775, "my_standard": 70.19880414009094, "standard_compiled": 56.40875548124313, "sb": 56.22846633195877}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "attn": 4.080049693584442, "ln": 3.2655522227287292, "ln_compiled": 3.3329352736473083, "gelu": 2.108432352542877, "gelu_compiled": 2.114713191986084, "standard": 11.370822787284851, "my_standard": 12.234866619110107, "standard_compiled": 10.377615690231323, "sb": 10.209612548351288}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "attn": 7.74645060300827, "ln": 6.418220698833466, "ln_compiled": 6.55733048915863, "gelu": 4.163652658462524, "gelu_compiled": 4.171028733253479, "standard": 21.39316499233246, "my_standard": 23.04024249315262, "standard_compiled": 19.431106746196747, "sb": 18.732361495494843}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "attn": 15.235155820846558, "ln": 12.684382498264313, "ln_compiled": 12.895286083221436, "gelu": 8.228868246078491, "gelu_compiled": 8.242718875408173, "standard": 42.55136102437973, "my_standard": 45.82635313272476, "standard_compiled": 38.663335144519806, "sb": 36.76284849643707}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "attn": 30.24454414844513, "ln": 25.25731921195984, "ln_compiled": 25.67601203918457, "gelu": 16.384944319725037, "gelu_compiled": 16.409948468208313, "standard": 84.26841348409653, "my_standard": 91.10662341117859, "standard_compiled": 76.89539343118668, "sb": 71.73164188861847}
import json
import time
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when the embed dim is too large.
# Not that big of an issue in practice.
def get_time_standard_fwd(k, v):
    """Benchmark the fp16 gradient-weight matmul ``g.T @ x`` and return its mean latency in ms.

    NOTE(review): reads the module-level globals ``batch_size``, ``dim_in``,
    ``dim_out`` and ``repeat`` set up in the ``__main__`` block; ``v`` is
    unused but kept for call-site compatibility.
    """
    x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
    g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()

    # warmup iterations so timing excludes kernel autotuning / cache effects
    for _ in range(repeat // 2):
        g.t().matmul(x)
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(repeat):
        g.t().matmul(x)
    torch.cuda.synchronize()
    elapsed_ms = (time.time() - start) / repeat * 1000

    print(f"time {k}: {elapsed_ms:.3f} ms")
    return elapsed_ms
if __name__ == '__main__':
    torch.manual_seed(0)

    def bench(name, fn):
        # Warm up for repeat // 2 iterations, then time `repeat` iterations of
        # `fn` (bracketed by cuda synchronize) and print/record the mean
        # per-iteration latency in ms.  Reads the module-level `repeat` and
        # `info` assigned in the loop below before any call.
        for _ in range(repeat // 2):
            fn()
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(repeat):
            fn()
        torch.cuda.synchronize()
        ms = (time.time() - start) / repeat * 1000
        print(f"time {name}: {ms:.3f} ms")
        info[name] = ms

    #for (dim, wm) in [(1024, 4), (1280, 4), (1408, 4.3637), (1664, 4.9231), (2048, 4), (4096, 4), (8096, 4)]
    for (dim, wm) in [(1408, 4), (1664, 4),]:
        for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
            #for batch_size in [256*256, 256*512]:
            for switch in [False, True]:

                # hparams: `switch` benchmarks the transposed shape
                # (dim_in/dim_out swapped) so both the forward and backward
                # projection of a `wm`-times-wider MLP layer are covered.
                repeat = 64
                dim_out = dim * wm
                dim_in = dim
                if switch:
                    dim_out = dim
                    dim_in = wm * dim
                dim_in = round(dim_in)    # wm may be fractional in other configs
                dim_out = round(dim_out)

                # simulate forward/backward tensors: activations x, gradient g,
                # weight w, int8 copies, and the per-row / per-column / global
                # max states the dequantizing matmuls expect.
                x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
                g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
                w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()

                x_int8 = x.clone().to(torch.int8)
                g_int8 = g.clone().to(torch.int8)
                w_int8 = w.clone().to(torch.int8)
                wt_int8 = w.t().contiguous().clone().to(torch.int8)
                state_x_rowwise = x.max(dim=1)[0]
                state_g_rowwise = g.max(dim=1)[0]
                state_w_columnwise = w.max(dim=0)[0]
                state_w_rowwise = w.max(dim=1)[0]
                state_w_global = w.max()

                info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}

                # fp16 baselines
                bench('standard_fwd', lambda: x.matmul(w.t()))
                bench('standard_gw', lambda: g.t().matmul(x))
                bench('standard_gx', lambda: g.matmul(w))

                # int8 matmuls (rowwise and global/mixed dequantization)
                bench('rowwise_fwd', lambda: int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise))
                # bugfix: the backward matmul consumes the gradient g, so it
                # must be paired with state_g_rowwise (was state_x_rowwise).
                bench('rowwise_bwd', lambda: int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_g_rowwise, state_w_rowwise))
                bench('global_fwd', lambda: int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global))
                # bugfix: same state_g_rowwise pairing as rowwise_bwd above.
                bench('global_bwd', lambda: int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_g_rowwise, state_w_global))

                # quantization overheads
                bench('x_quantize_rowwise', lambda: quantize_rowwise_nogroup(x))
                bench('g_quantize_rowwise', lambda: quantize_rowwise_nogroup(g))
                bench('w_quantize_rowwise', lambda: quantize_rowwise_nogroup(w))
                bench('w_quantize_colwise_transpose', lambda: quantize_columnwise_nogroup_transpose(w))
                bench('w_quantize_global', lambda: quantize_global(w))
                bench('w_quantize_global_transpose', lambda: quantize_global_transpose(w))

                # plain dtype casts, for reference
                bench('cast_x', lambda: x.to(torch.int8))
                bench('cast_g', lambda: g.to(torch.int8))
                bench('cast_w', lambda: w.to(torch.int8))

                # end-to-end estimates for a full fwd+bwd of one linear layer
                time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
                time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
                time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']

                print('TOTAL STANDARD', time_standard)
                print('TOTAL ROWWISE', time_rowwise)
                print('TOTAL GLOBAL', time_global)
                print('speedup', -100*(time_global - time_standard)/time_standard)

                info['time_standard'] = time_standard
                info['time_rowwise'] = time_rowwise
                info['time_global'] = time_global

                info_json = json.dumps(info)
                with open("tests/triton_tests/info.jsonl", "a") as file:
                    file.write(info_json + "\n")
\ No newline at end of file
This diff is collapsed.
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 3.807276487350464, "my_standard": 4.196919500827789, "standard_compiled": 3.771558403968811, "sb": 3.5132691264152527}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 7.215872406959534, "my_standard": 7.991522550582886, "standard_compiled": 7.241688668727875, "sb": 6.581142544746399}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 14.26444947719574, "my_standard": 15.685759484767914, "standard_compiled": 14.251746237277985, "sb": 12.735314667224884}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 28.49559485912323, "my_standard": 31.26966953277588, "standard_compiled": 28.414390981197357, "sb": 25.319166481494904}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 5.887262523174286, "my_standard": 6.132654845714569, "standard_compiled": 5.902409553527832, "sb": 4.947789013385773}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 11.14131510257721, "my_standard": 12.859955430030823, "standard_compiled": 11.133037507534027, "sb": 9.303092956542969}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 22.193141281604767, "my_standard": 25.66336840391159, "standard_compiled": 22.22583442926407, "sb": 18.285617232322693}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 44.23898458480835, "my_standard": 51.30268633365631, "standard_compiled": 44.08355802297592, "sb": 35.999126732349396}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 6.938718259334564, "my_standard": 7.269218564033508, "standard_compiled": 6.94604218006134, "sb": 5.764961242675781}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 13.04878294467926, "my_standard": 13.742901384830475, "standard_compiled": 13.011425733566284, "sb": 10.774023830890656}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 26.738539338111877, "my_standard": 27.739346027374268, "standard_compiled": 26.75659954547882, "sb": 21.882005035877228}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 51.905401051044464, "my_standard": 53.98637801408768, "standard_compiled": 51.8316924571991, "sb": 41.67725890874863}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 9.233824908733368, "my_standard": 9.619377553462982, "standard_compiled": 9.214423596858978, "sb": 7.557623088359833}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 17.324909567832947, "my_standard": 17.996780574321747, "standard_compiled": 17.29544997215271, "sb": 14.035224914550781}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 35.51657497882843, "my_standard": 36.674730479717255, "standard_compiled": 35.43049842119217, "sb": 28.38330715894699}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 69.0087378025055, "my_standard": 71.56594842672348, "standard_compiled": 68.82885098457336, "sb": 54.01633679866791}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 12.590140104293823, "my_standard": 13.106442987918854, "standard_compiled": 12.606985867023468, "sb": 10.286301374435425}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 24.830535054206848, "my_standard": 25.563716888427734, "standard_compiled": 24.895809590816498, "sb": 19.559212028980255}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 49.55078661441803, "my_standard": 51.16480588912964, "standard_compiled": 49.739621579647064, "sb": 38.29141706228256}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 98.36294502019882, "my_standard": 102.69322991371155, "standard_compiled": 98.76712411642075, "sb": 75.88706165552139}
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 4.91420179605484, "my_standard": 5.577877163887024, "standard_compiled": 4.810944199562073, "sb": 4.512995481491089}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 8.876129984855652, "my_standard": 10.154612362384796, "standard_compiled": 8.820965886116028, "sb": 8.367843925952911}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 17.47015118598938, "my_standard": 19.857674837112427, "standard_compiled": 17.338842153549194, "sb": 15.992552042007446}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 34.824438393116, "my_standard": 39.499424397945404, "standard_compiled": 34.56207364797592, "sb": 31.573951244354248}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.342606782913208, "my_standard": 7.9323723912239075, "standard_compiled": 7.279552519321442, "sb": 6.395488977432251}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 13.69999349117279, "my_standard": 16.0503089427948, "standard_compiled": 13.603456318378448, "sb": 11.813104152679443}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 29.557034373283386, "my_standard": 34.2303067445755, "standard_compiled": 29.382556676864624, "sb": 22.882774472236633}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 53.629085421562195, "my_standard": 63.07622790336609, "standard_compiled": 53.33048850297928, "sb": 44.76426541805267}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 8.81417840719223, "my_standard": 9.477965533733368, "standard_compiled": 8.73943418264389, "sb": 7.479414343833923}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 16.242466866970062, "my_standard": 17.616644501686096, "standard_compiled": 16.14125818014145, "sb": 13.665586709976196}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 32.429613173007965, "my_standard": 34.80646014213562, "standard_compiled": 32.319076359272, "sb": 27.123987674713135}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 62.85770237445831, "my_standard": 67.55391508340836, "standard_compiled": 62.453076243400574, "sb": 51.53566598892212}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 11.585861444473267, "my_standard": 12.565858662128448, "standard_compiled": 11.504307389259338, "sb": 9.657211601734161}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 21.261662244796753, "my_standard": 22.771358489990234, "standard_compiled": 21.12410217523575, "sb": 17.64291524887085}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 42.85307973623276, "my_standard": 45.70870101451874, "standard_compiled": 42.57970303297043, "sb": 34.918561577796936}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 83.56057852506638, "my_standard": 89.11971747875214, "standard_compiled": 83.05662125349045, "sb": 66.32210314273834}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 15.7279372215271, "my_standard": 16.854502260684967, "standard_compiled": 15.655294060707092, "sb": 13.228952884674072}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 30.42648732662201, "my_standard": 32.26502239704132, "standard_compiled": 30.239209532737732, "sb": 24.354808032512665}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 60.779355466365814, "my_standard": 64.11923468112946, "standard_compiled": 60.89268624782562, "sb": 46.91776633262634}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 119.93677169084549, "my_standard": 128.19699943065643, "standard_compiled": 120.20225822925568, "sb": 92.3452153801918}
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 5.171686410903931, "my_standard": 5.839601159095764, "standard_compiled": 5.032263696193695, "sb": 4.89344447851181}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 9.605035185813904, "my_standard": 10.910414159297943, "standard_compiled": 9.230785071849823, "sb": 9.128175675868988}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 18.802084028720856, "my_standard": 21.311581134796143, "standard_compiled": 18.105976283550262, "sb": 17.489850521087646}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 37.49683499336243, "my_standard": 42.40527004003525, "standard_compiled": 36.13145649433136, "sb": 34.58733111619949}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.709823548793793, "my_standard": 8.290477097034454, "standard_compiled": 7.564418017864227, "sb": 6.8823546171188354}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 14.64156061410904, "my_standard": 16.996942460536957, "standard_compiled": 14.4081711769104, "sb": 12.761622667312622}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 31.40200674533844, "my_standard": 36.074504256248474, "standard_compiled": 30.981406569480896, "sb": 24.76389706134796}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 56.93405121564865, "my_standard": 66.35250151157379, "standard_compiled": 56.07586354017258, "sb": 48.49743843078613}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 9.188003838062286, "my_standard": 9.84550267457962, "standard_compiled": 9.006097912788391, "sb": 7.9473331570625305}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 17.268165946006775, "my_standard": 18.64910125732422, "standard_compiled": 16.983114182949066, "sb": 14.70106840133667}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 34.39047932624817, "my_standard": 36.69705241918564, "standard_compiled": 33.8401272892952, "sb": 29.188089072704315}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 66.70494377613068, "my_standard": 71.27603143453598, "standard_compiled": 65.56134670972824, "sb": 55.6538850069046}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 12.10707426071167, "my_standard": 12.931793928146362, "standard_compiled": 11.76995038986206, "sb": 10.228671133518219}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 22.5130096077919, "my_standard": 23.962542414665222, "standard_compiled": 21.997176110744476, "sb": 18.89890432357788}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 45.210108160972595, "my_standard": 47.94136434793472, "standard_compiled": 44.2262664437294, "sb": 37.37735003232956}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 88.1955549120903, "my_standard": 93.6831533908844, "standard_compiled": 86.33609116077423, "sb": 71.23208791017532}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 16.538940370082855, "my_standard": 17.607316374778748, "standard_compiled": 16.108587384223938, "sb": 14.030493795871735}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 31.795650720596313, "my_standard": 33.57230871915817, "standard_compiled": 31.04180097579956, "sb": 25.971196591854095}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 63.021354377269745, "my_standard": 66.8477788567543, "standard_compiled": 61.682507395744324, "sb": 50.138771533966064}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 125.17062574625015, "my_standard": 133.60925763845444, "standard_compiled": 122.21191823482513, "sb": 98.40084612369537}
{"repeat": 32, "batch_size": 16384, "dim": 4096, "standard": 57.31645971536636, "my_standard": 60.84543466567993, "standard_compiled": 55.78199774026871, "sb": 45.43223977088928}
{"repeat": 32, "batch_size": 32768, "dim": 4096, "standard": 111.80306226015091, "my_standard": 119.0284714102745, "standard_compiled": 108.91905426979065, "sb": 85.4572057723999}
{"repeat": 32, "batch_size": 65536, "dim": 4096, "standard": 220.4471081495285, "my_standard": 233.0927476286888, "standard_compiled": 214.26431089639664, "sb": 163.30372542142868}
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':

    def fwd_bwd_avg_ms(frame, key_expr, embed_dim):
        # Sum the '+'-joined timing columns for the forward-shaped row
        # (dim_in=embed_dim, dim_out=4*embed_dim) and the backward-shaped row
        # (dims swapped), returning the average of the two.  Previously this
        # aggregation was duplicated inline in both panels below.
        total = 0.0
        for d_in, d_out in ((embed_dim, embed_dim * 4), (embed_dim * 4, embed_dim)):
            row = frame[(frame.dim_in == d_in) & (frame.dim_out == d_out)]
            for col in key_expr.split('+'):
                total += row[col].values[0]
        return total * 0.5

    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    gs = gridspec.GridSpec(1, 2)

    # left panel: absolute time of each kernel at a fixed token count (32k)
    ax = fig.add_subplot(gs[0, 0])

    rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True)
    df = rdf[rdf.batch_size == 32768]

    for k, marker, ls, color, name in [
        ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
        ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
        ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
        ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
        ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
        ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
        ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
        ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
        ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
        # label typos fixed: 'Quatnize' -> 'Quantize', 'gloabl' -> 'global'
        ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
        ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
        #('standard_gw', '.', '--', 'C1', 'standard_gw'),
    ]:
        xs, ys = [], []
        for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]:
            xs.append(embed_dim)
            ys.append(fwd_bwd_avg_ms(df, k, embed_dim))
        ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker == 's' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)

    ax.set_xlabel('dim', fontsize=13)
    ax.set_ylabel('time (ms)', fontsize=13)
    ax.grid()
    ax.set_xscale('log')
    #ax.set_yscale('log')
    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)
    ax.set_xticks([1024, 2048, 4096])
    ax.set_xticklabels([1024, 2048, 4096])
    ax.set_xticks([], minor=True)
    # legend placed to the left of the axes; the two "sum of parts" rows bold
    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
    leg.get_texts()[0].set_fontweight('bold')
    leg.get_texts()[1].set_fontweight('bold')
    plt.subplots_adjust(left=0.1)
    ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)

    # right panel: % speedup of SwitchBack over fp16, one line per batch size
    ax = fig.add_subplot(gs[0, 1])
    for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]):
        all_xs, all_ys = [], []
        for k, marker, ls, color, name in [
            ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
            ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
        ]:
            df = rdf[rdf.batch_size == batch_size]
            xs = [1024, 1280, 1408, 1664, 2048, 4096]
            ys = [fwd_bwd_avg_ms(df, k, embed_dim) for embed_dim in xs]
            all_xs.append(xs)
            all_ys.append(ys)

        color = cmap(j * 0.25)
        real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
        markers = ['^', 'v', 'P', 'o']
        ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker == 's' else 5)

    ax.legend()
    ax.set_xlabel('dim', fontsize=13)
    ax.set_xscale('log')
    ax.grid()
    ax.set_ylabel(r'% speedup', fontsize=13)
    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)
    ax.set_xticks([1024, 2048, 4096])
    ax.set_xticklabels([1024, 2048, 4096])
    ax.set_xticks([], minor=True)
    ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)

    plt.savefig('tests/triton_tests/plot1.pdf', bbox_inches='tight')
import time
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
def construct_model(dim, layers, module, expansion=4):
    """Build a stack of MLP-style blocks and return it on CUDA in train mode.

    Each of the `layers` blocks is two `module` linears: dim -> expansion*dim
    followed by expansion*dim -> dim.  `expansion` generalizes the previously
    hard-coded 4x widening factor (default preserves old behavior).
    """
    modules = []
    for _ in range(layers):
        modules.append(module(dim, expansion * dim))
        modules.append(module(expansion * dim, dim))
    return nn.Sequential(*modules).cuda().train()
def get_time(model, x, name):
    """Time `model(x)` forward passes and print the mean latency in ms.

    Warms up for repeat // 2 calls, then averages over `repeat` timed calls
    bracketed by cuda synchronize.  NOTE(review): `repeat` is a module-level
    global set in the `__main__` block.  (Autocast / backward-pass variants
    were present here commented out; re-add around the calls if needed.)
    """
    for _ in range(repeat // 2):
        model(x)
    torch.cuda.synchronize()

    t0 = time.time()
    for _ in range(repeat):
        model(x)
    torch.cuda.synchronize()
    elapsed = time.time() - t0

    print(f"time {name}: {elapsed / repeat * 1000:.3f} ms")
if __name__ == '__main__':
    torch.manual_seed(0)

    # hparams (`repeat` is read as a global by get_time)
    repeat = 16
    dim = 2048
    layers = 4
    batch_size = 2
    sequence_length = 2**15

    # construct the model variants to compare
    standard = construct_model(dim, layers, nn.Linear).half()
    my_standard = construct_model(dim, layers, MyLinear).half()
    switchback = construct_model(dim, layers, SwitchBackLinear).half()
    switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
    #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)

    # one shared fp16 input batch for every model
    x = torch.randn(batch_size * sequence_length, dim, dtype=torch.float16).cuda()

    # benchmark forward passes, in a fixed order so output lines are stable
    for model, label in [
        (standard, "standard"),
        (my_standard, "my_standard"),
        (switchback, "switchback"),
        (switchback_global, "switchback_global"),
        #(bnb_8bitmixed, "bnb_8bitmixed"),
    ]:
        get_time(model, x, label)
\ No newline at end of file
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
import time
if __name__ == '__main__':
    # Benchmark a SwitchBack MLP (Linear -> GELU -> Linear) fwd+bwd under
    # autocast across dim/batch configurations, appending one JSON line per
    # configuration to tests/triton_tests/info_mlp_autocast.jsonl.
    print('Starting')  # fixed typo ('Startin')

    for dim in [1024, 1280, 1408, 1664, 2048]:
        for batch in [2**14, 2**15, 2**16, 2**17]:
            # BUG FIX: a leftover debug filter here
            # ("if dim != 4096 or batch != 2**17: continue") skipped every
            # configuration in the loops above (dim never equals 4096), so
            # the script produced no output at all; it has been removed.
            # Dead code was also dropped: unused `standard`, `my_standard`,
            # `fused_mlp` and `standard_compiled` models (only exercised by
            # fully commented-out benchmark sections) and the unused `d = 2`.

            x1 = torch.randn(batch, dim).cuda().requires_grad_(True)

            # the model under test: SwitchBack MLP with a 4x hidden expansion
            sb = torch.nn.Sequential(
                SwitchBackGlobalLinear(dim, 4 * dim),
                torch.nn.GELU(),
                SwitchBackGlobalLinear(4 * dim, dim),
            ).cuda()

            print('Model part 2')

            repeat = 32
            info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim}

            k = 'sb'
            # warmup: fwd + bwd under autocast (the 2**16 factor scales the
            # loss so fp16 gradients don't underflow)
            for _ in range(repeat // 2):
                with torch.cuda.amp.autocast():
                    out_sb = sb(x1)
                ((2 ** 16) * out_sb).abs().mean().backward()
            torch.cuda.synchronize()

            start = time.time()
            for _ in range(repeat):
                with torch.cuda.amp.autocast():
                    out_sb = sb(x1)
                ((2 ** 16) * out_sb).abs().mean().backward()
            torch.cuda.synchronize()
            end = time.time()
            ms = (end - start) / repeat * 1000
            print(f"time {k}: {ms:.3f} ms")
            info[k] = ms

            info_json = json.dumps(info)
            with open("tests/triton_tests/info_mlp_autocast.jsonl", "a") as file:
                file.write(info_json + "\n")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment