Commit 7f87ba83 authored by Mitchell Wortsman

cleaning and refactor

parent 30d21d58
@@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
import torch
import torch.nn as nn
import time
from functools import partial

from .triton_utils.v0.quantize_rowwise import quantize_rowwise
from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze


class _switchback_global(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_global(W)

        # save for backward.
        ctx.save_for_backward = X, W

        # matmult, fused dequant and add bias
        # call "mixed" because we are mixing rowwise quantized and global quantized
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape input to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        X, W = ctx.save_for_backward
        if ctx.needs_input_grad[0]:
            # rowwise quantize for G, global quantize for W
            # for W, we also fuse the transpose operation because only A @ B^T is supported
            # so we transpose once then call .t() in the matmul
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias
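# A minimal sketch of the gradients computed above, assuming G: [B, D_out],
# X: [B, D_in], W: [D_out, D_in] (shapes inferred from the forward pass, not
# stated explicitly in this file):
#   grad_X    = G @ W           # int8 matmul on rowwise-quantized G and globally quantized W, then dequantize
#   grad_W    = G.t() @ X       # standard fp16 matmul
#   grad_bias = G.sum(dim=0)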
class _switchback_vectorrize(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        ctx.save_for_backward = X, W
        # rowwise quantize for X
        # columnwise quantize for W (first rowwise, transpose later)
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_rowwise(W)

        # matmult, fused dequant and add bias
        # call kernel which expects rowwise quantized X and W
        return int8_matmul_rowwise_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.save_for_backward

        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        if ctx.needs_input_grad[0]:
            # rowwise quantize for G, columnwise quantize for W and fused transpose
            # we call .t() for weight later because only A @ B^T is supported
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_columnwise_and_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias


class SwitchBackLinear(nn.Linear):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None,
            vectorize: bool = False
        ):
        super().__init__(in_features, out_features, bias, device, dtype)

        # By default, we use the global quantization.
        self.vectorize = vectorize
        if self.vectorize:
            self._fn = _switchback_vectorrize
        else:
            self._fn = _switchback_global

    def prepare_for_eval(self):
        # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
        # Note this is experimental and not tested thoroughly.
        # Note this needs to be explicitly called with something like
        # def cond_prepare(m):
        #     if hasattr(m, "prepare_for_eval"):
        #         m.prepare_for_eval()
        # model.apply(cond_prepare)
        print('=> preparing for eval.')
        if self.vectorize:
            W_int8, state_W = quantize_rowwise(self.weight)
        else:
            W_int8, state_W = quantize_global(self.weight)

        self.register_buffer("W_int8", W_int8)
        self.register_buffer("state_W", state_W)

@@ -120,19 +138,29 @@ class SwitchBackGlobalLinear(nn.Linear):

    def forward(self, x):
        if self.training:
            return self._fn.apply(x, self.weight, self.bias)
        else:
            # If it hasn't been "prepared for eval", run the standard forward pass.
            if not hasattr(self, "W_int8"):
                return self._fn.apply(x, self.weight, self.bias)

            # Otherwise, use pre-computed weights.
            X = x.view(-1, x.size(-1))
            X_int8, state_X = quantize_rowwise(X)

            if self.vectorize:
                return int8_matmul_rowwise_dequantize(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)
            else:
                return int8_matmul_mixed_dequanitze(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)


SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
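# A minimal usage sketch (assumptions: a CUDA device and fp16 tensors, which is
# the path the Triton kernels above handle; shapes are illustrative only):
#
#   from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackLinearVectorized
#
#   linear     = SwitchBackLinear(1024, 4096).cuda().half()            # global W quantization (default)
#   linear_vec = SwitchBackLinearVectorized(1024, 4096).cuda().half()  # rowwise/columnwise W quantization
#
#   x = torch.randn(8, 256, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
#   y = linear(x)          # int8 matmul with fused dequantize in the forward
#   y.mean().backward()    # int8 matmul for grad_X, fp16 matmul for grad_W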
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
...
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
tl.libdevice
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup_gelu(
x_ptr,
output_ptr,
output_maxs,
output_fp16,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
cdf = 0.5 * (1.0 + tl.libdevice.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
x_new = x * cdf
tl.store(output_fp16 + offsets, x_new, mask=row_mask)
abs_x = tl.abs(x_new)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x_new / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup_gelu(x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise_nogroup_gelu[grid](x, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs, output_fp16
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup_back_gelu(
x_ptr,
in_ptr,
output_ptr,
output_maxs,
output_fp16,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x_out = tl.load(x_ptr + offsets, mask=row_mask)
x_in = tl.load(in_ptr + offsets, mask=row_mask)
cdf = 0.5 * (1.0 + tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in)))
intermediate = tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in))
dcdf = 0.5 * (0.7978845608 + 0.1070322243 * x_in * x_in) * (1 - intermediate * intermediate)
x = x_out * (cdf + x_in * dcdf)
tl.store(output_fp16 + offsets, x, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup_back_gelu(x: torch.Tensor, y : torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise_nogroup_back_gelu[grid](x, y, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs, output_fp16
@@ -5,10 +5,14 @@ import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This is a matmul kernel based on triton.ops.matmul.
# It is modified to support rowwise quantized input and global quantized weight.
# Its purpose is a fused matmul then dequantize.
# It does support bias.
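# As a rough reference (a sketch for clarity, not library API), for int8 inputs A
# (rowwise-quantized X) and B (globally-quantized W^T) the fused kernel computes
#
#   C = (state_x[:, None] / 127) * (state_w / 127) * (A @ B) + bias
#
# i.e. divfactor = 1 / (127 * 127) below, with state_x the per-row absmax of X
# and state_w the single absmax of W.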
def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
@@ -60,130 +64,7 @@ def get_configs_io_bound():
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
@@ -236,6 +117,7 @@ def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl
    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)

    # conditionally add bias
    if has_bias:
        bias = tl.load(bias + rn).to(C.dtype.element_ty)
        acc = acc + bias[None, :]
@@ -249,7 +131,7 @@ def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl
        tl.atomic_add(C, acc, mask=mask)


def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
    device = a.device
    divfactor = 1. / (127. * 127.)
    has_bias = 0 if bias is None else 1
@@ -266,9 +148,9 @@ def int8_matmul_mixed_dequanitze_bias(a, b, state_x, state_w, bias):
    c = torch.empty((M, N), device=device, dtype=torch.float16)
    # accumulator types
    ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
    # launch int8_matmul_mixed_dequantize kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                    a.stride(0), a.stride(1),
                    b.stride(0), b.stride(1),
                    c.stride(0), c.stride(1),
...
@@ -4,6 +4,10 @@ import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This is a matmul kernel based on triton.ops.matmul.
# It is modified to support rowwise quantized input and columnwise quantized weight.
# Its purpose is a fused matmul then dequantize.
# It does support bias.
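# Reference sketch (for clarity only, not library API): compared to the "mixed"
# kernel above, here state_w holds one absmax per column of W, so the dequantize is
#
#   C = (state_x[:, None] / 127) * (state_w[None, :] / 127) * (A @ B) + bias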
def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()
@@ -60,7 +64,7 @@ def get_configs_io_bound():
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
@@ -113,6 +117,10 @@ def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor,
    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)

    if has_bias:
        bias = tl.load(bias + rn).to(C.dtype.element_ty)
        acc = acc + bias[None, :]

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
@@ -122,9 +130,11 @@ def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor,
        tl.atomic_add(C, acc, mask=mask)


def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
    divfactor = 1. / (127. * 127.)

    has_bias = 0 if bias is None else 1

    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
@@ -139,9 +149,9 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w):
    c = torch.empty((M, N), device=device, dtype=torch.float16)
    # accumulator types
    ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
    # launch int8_matmul_rowwise_dequantize kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                    a.stride(0), a.stride(1),
                    b.stride(0), b.stride(1),
                    c.stride(0), c.stride(1),
...
import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
w_factor = tl.load(state_w_ptr + rbn)[None, :]
x_factor = tl.load(state_x_ptr + ram)[:, None]
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize_bias(a, b, state_x, state_w, bias):
#print(bias)
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
if bias is not None:
bias = bias.contiguous()
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
@@ -5,6 +5,8 @@ import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This kernel does fused columnwise quantization and transpose.

# TODO: autotune this better.
@triton.autotune(
    configs=[
@@ -26,7 +28,7 @@ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_tim
    key=['n_elements']
)
@triton.jit
def _quantize_columnwise_and_transpose(
    x_ptr,
    output_ptr,
    output_maxs,
@@ -51,7 +53,7 @@ def _quantize_columnwise_nogroup_transpose(
    tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
    tl.store(output_maxs + pid, max_val)

def quantize_columnwise_and_transpose(x: torch.Tensor):
    M, N = x.shape
    output = torch.empty(N, M, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
@@ -61,62 +63,6 @@ def quantize_columnwise_nogroup_transpose(x: torch.Tensor):
    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
    return output, output_maxs
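# A minimal PyTorch reference for the wrapper above (a sketch, assuming a 2D fp16
# CUDA tensor x of shape [M, N]; not part of the library API):
#
#   def reference_quantize_columnwise_and_transpose(x):
#       absmax = x.abs().max(dim=0, keepdim=True)[0]                    # one scale per column of x
#       int8_t = torch.round(127. * (x / absmax)).t().contiguous().to(torch.int8)
#       return int8_t, absmax.squeeze(0).half()                         # output is already transposed, shape [N, M]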
if __name__ == '__main__':
torch.manual_seed(0)
x = torch.randn(1280, 768).cuda().to(torch.float16)
out = quantize_columnwise_nogroup_transpose(x)
x_real = x.t().float()
x_real_int8 = (127. * x_real / x_real.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
maxs = x_real.abs().max(dim=1, keepdim=True)[0].half()
#print(out[0][2,:])
print((out[0] == x_real_int8).float().mean())
print((out[1] == maxs[:, 0]).float().mean())
# print(out[0])
# print(out[1])
# print(out[0][2,:])
# print(x_real[2, :])
# print((out[0] != x_real).nonzero())
#import pdb; pdb.set_trace()
# repeat = 16
# for _ in range(8):
# out = quantize_columnwise_nogroup_transpose(x)
# triton_graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(triton_graph):
# out = quantize_columnwise_nogroup_transpose(x)
# triton_graph.replay()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# triton_graph.replay()
# torch.cuda.synchronize()
# end = time.time()
# print(out[0])
# print(out[1])
# print(x / x.abs().max(dim=0, keepdim=True)[0])
# x_real = (127 * (x / x.abs().max(dim=0, keepdim=True)[0])).round().to(torch.int8)
# max1 = out[1]
# max2 = x.abs().max(0)[0]
# print(max1, max2)
# import pdb; pdb.set_trace()
# print(torch.allclose(max1, max2))
# print(f"time: {(end - start) / repeat * 1000:.3f} ms")
@@ -5,7 +5,7 @@ import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# global quantize
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
@@ -42,6 +42,7 @@ def quantize_global(x: torch.Tensor):
    return output, absmax

# global quantize and transpose
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
@@ -97,34 +98,3 @@ def quantize_global_transpose(input):
    _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
    return out, absmax
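# Rough PyTorch reference for the two wrappers above (a sketch, assuming a 2D fp16
# CUDA tensor w; not part of the library API):
#
#   absmax  = w.abs().max()                                                   # a single, "global" scale
#   w_int8  = torch.round(127. * w / absmax).to(torch.int8)                   # quantize_global
#   wt_int8 = torch.round(127. * w / absmax).t().contiguous().to(torch.int8)  # quantize_global_transpose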
if __name__ == '__main__':
w = torch.randn(768, 1280).cuda().to(torch.float16)
W_int8, state_w = quantize_global(w)
r_state_w = w.abs().max()
r_W_int8 = ((127 * w.float()) / state_w).round().to(torch.int8)
print((r_W_int8 == W_int8).float().mean())
# print(r_W_int8)
# print(W_int8)
exit()
repeat = 16
for _ in range(8):
out = quantize_global(w)
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph):
out = quantize_global(w)
triton_graph.replay()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
triton_graph.replay()
torch.cuda.synchronize()
end = time.time()
print(f"time: {(end - start) / repeat * 1000:.3f} ms")
@@ -5,6 +5,8 @@ import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise quantize

# TODO: autotune this better.
@triton.autotune(
    configs=[
@@ -24,7 +26,7 @@ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_tim
    key=['n_elements']
)
@triton.jit
def _quantize_rowwise(
    x_ptr,
    output_ptr,
    output_maxs,
@@ -45,7 +47,7 @@ def _quantize_rowwise_nogroup(
    tl.store(output_ptr + offsets, output, mask=row_mask)
    tl.store(output_maxs + pid, max_val)

def quantize_rowwise(x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
@@ -54,121 +56,6 @@ def quantize_rowwise_nogroup(x: torch.Tensor):
    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)
    _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output, output_maxs
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _experimental_quantize_rowwise_nogroup(
x_ptr,
output_ptr,
bias_grad_ptr,
output_maxs,
n_elements,
M: tl.constexpr, N: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
P2M: tl.constexpr,
):
pid = tl.program_id(axis=0)
if pid < M:
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
else:
real_pid = pid - M
arange_new = tl.arange(0, P2M)
mask_new = arange_new < M
offsets_new = real_pid + arange_new * N
new_x = tl.load(x_ptr + offsets_new, mask=mask_new)
s = tl.sum(tl.where(mask_new, new_x, 0).to(tl.float32), axis=0)
tl.store(bias_grad_ptr + real_pid, s)
def experimental_quantize_rowwise_nogroup(x: torch.Tensor):
M, N = x.shape
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
bias_grad = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
P2M = int(2 ** (math.ceil(math.log2(x.shape[0]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0] + x.shape[1],)
_experimental_quantize_rowwise_nogroup[grid](x, output, bias_grad, output_maxs, n_elements, M, N, BLOCK_SIZE=x.shape[1], P2=P2, P2M=P2M)
return output, output_maxs, bias_grad
if __name__ == '__main__':
torch.manual_seed(0)
x = torch.randn(1280, 768).cuda().to(torch.float16)
out = quantize_rowwise_nogroup(x)
x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
max2 = x.abs().max(1)[0]
print(torch.allclose(out[1], max2))
print( (x_real == out[0]).float().mean() )
# for i in range(x.shape[0]):
# print( (x_real[i, :] == out[0][i, :]).float().mean() )
# print(out[0])
# print(x_real)
# import pdb; pdb.set_trace()
# print(out[2])
# print(out[2][:10])
sums = x.sum(dim=0)
#print(sums[:10])
#print( (sums == out[2]).float().mean() )
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
# exit()
# repeat = 16
# for _ in range(8):
# out = quantize_rowwise_nogroup(x)
# triton_graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(triton_graph):
# out = quantize_rowwise_nogroup(x)
# triton_graph.replay()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# triton_graph.replay()
# torch.cuda.synchronize()
# end = time.time()
# print(out[0])
# print(out[1])
# print(x / x.abs().max(dim=1, keepdim=True)[0])
# max1 = out[1]
# max2 = x.abs().max(1)[0]
# print(max1, max2)
# print(torch.allclose(max1, max2))
#print(f"time: {(end - start) / repeat * 1000:.3f} ms")
@@ -12,12 +12,18 @@ if __name__ == '__main__':

    fig = plt.figure(tight_layout=True, figsize=(12,3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
    batch_size_for_plot1 = 32768
    batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
    dims_to_xtick = [1024, 2048, 4096]
    logscale_plot1 = True

    ax = fig.add_subplot(gs[0, 0])

    rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # first plot the time occupied by different operations
    for k, marker, ls, color, name in [
        ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
        ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
@@ -29,17 +35,15 @@ if __name__ == '__main__':
        ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
        ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),

        ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
        ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
        ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
        ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
    ]:
        xs = []
        ys = []
        for embed_dim in dims_to_consider:
            # average over dim -> 4*dim and 4*dim -> dim
            df_ = df[df.dim_in == embed_dim]
            df_ = df_[df_.dim_out == embed_dim * 4]
            xs.append(embed_dim)
@@ -56,24 +60,20 @@ if __name__ == '__main__':
        ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)

    ax.set_xlabel('dim', fontsize=13)
    ax.set_ylabel('time (ms)', fontsize=13)

    ax.grid()

    ax.set_xscale('log')
    if logscale_plot1:
        ax.set_yscale('log')

    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
@@ -86,7 +86,7 @@ if __name__ == '__main__':

    ax = fig.add_subplot(gs[0, 1])

    # now plot the % speedup for different batch sizes
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        all_xs, all_ys = [], []
        for k, marker, ls, color, name in [
            ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
@@ -95,7 +95,7 @@ if __name__ == '__main__':
            xs, ys = [], []
            df = rdf[rdf.batch_size == batch_size]
            for embed_dim in dims_to_consider:
                df_ = df[df.dim_in == embed_dim]
                df_ = df_[df_.dim_out == embed_dim * 4]
                xs.append(embed_dim)
@@ -125,13 +125,13 @@ if __name__ == '__main__':
    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)

    plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
import json
import time
import torch
import torch.nn as nn
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when the embed dim is too large.
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
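# Note on the timing above: the torch.cuda.synchronize() calls are what make the
# wall-clock measurement meaningful, since CUDA kernel launches are asynchronous;
# the initial repeat // 2 calls act as a warmup so Triton autotuning/compilation
# is excluded from the measured time.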
if __name__ == '__main__':
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
with open("speed_benchmark/info_a100_py2.jsonl", "a") as file:
file.write(info_json + "\n")
import pytest
import torch

from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt


@pytest.mark.parametrize("vectorrize", [False, True])
def test_switchback(vectorrize):
    for dim in [83, 17, 128]:
        for batch in [13, 128, 256]:

            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
            print('vectorrize', vectorrize)
            switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
            switchback.weight.data.copy_(standard.weight)
            switchback.bias.data.copy_(standard.bias)
            baseline.weight.data.copy_(standard.weight)
            baseline.bias.data.copy_(standard.bias)

            for i in range(100):
                x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
                x2 = x1.clone().detach().requires_grad_(True)
                x3 = x1.clone().detach().requires_grad_(True)

                out_standard = standard(x1)
                (2**10 * out_standard.abs().mean()).backward()

                out_sb = switchback(x2)
                (2**10 * out_sb.abs().mean()).backward()

                out_baseline = baseline(x3)
                (2**10 * out_baseline.abs().mean()).backward()

                err_sb = (out_standard - out_sb).abs().mean()
                err_baseline = (out_standard - out_baseline).abs().mean()
                print('OUT', err_sb, err_baseline)
                assert err_sb < 2 * err_baseline

                err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
                err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
                print('GW2', err_sb, err_baseline)
                assert err_sb < 2 * err_baseline

                err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
                err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
                print('GW1', err_sb, err_baseline)
                assert err_sb < 2 * err_baseline

                err_sb = (x1.grad - x2.grad).abs().mean()
                err_baseline = (x1.grad - x3.grad).abs().mean()
                print('GX1', err_sb, err_baseline)
                assert err_sb < 2 * err_baseline
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
import time
# class AttentionOld(torch.nn.Module):
# def __init__(
# self,
# dim,
# num_heads=8,
# qkv_bias=True,
# scaled_cosine=False,
# scale_heads=False,
# attn_drop=0.,
# proj_drop=0.,
# linear_module=torch.nn.Linear,
# ):
# super().__init__()
# self.scaled_cosine = scaled_cosine
# self.scale_heads = scale_heads
# assert dim % num_heads == 0, 'dim should be divisible by num_heads'
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim ** -0.5
# self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias)
# self.attn_drop = torch.nn.Dropout(attn_drop)
# if self.scale_heads:
# self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1)))
# else:
# self.head_scale = None
# self.out_proj = linear_module(dim, dim)
# self.out_drop = torch.nn.Dropout(proj_drop)
# def forward(self, x, attn_mask = None):
# L, N, C = x.shape
# q, k, v = self.in_proj_linear(x).chunk(3, dim=-1)
# q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# q = q * self.scale
# attn = torch.bmm(q, k.transpose(-1, -2))
# if attn_mask is not None:
# if attn_mask.dtype == torch.bool:
# new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
# new_attn_mask.masked_fill_(attn_mask, float("-inf"))
# attn_mask = new_attn_mask
# attn += attn_mask
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)
# x = torch.bmm(attn, v)
# x = x.transpose(0, 1).reshape(L, N, C)
# x = self.out_proj(x)
# x = self.out_drop(x)
# return x
class Attention(torch.nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
attn_drop=0.,
proj_drop=0.,
linear_module=torch.nn.Linear,
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.ln = torch.nn.LayerNorm(dim)
self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias)
self.attn_drop = torch.nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = linear_module(dim, dim)
self.out_drop = torch.nn.Dropout(proj_drop)
def forward(self, x, attn_mask = None):
q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
x = self.out_proj(x)
return x
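# A minimal usage sketch (hypothetical shapes; assumes a CUDA device and the
# SwitchBackGlobalLinear import above): only the linear_module argument changes
# between the fp16 baseline and the SwitchBack variant. _example_forward is an
# illustration, not part of the original benchmark.
def _example_forward():
    attn_fp16 = Attention(1024).cuda()
    attn_sb = Attention(1024, linear_module=SwitchBackGlobalLinear).cuda()
    x = torch.randn(64, 256, 1024).cuda()
    with torch.cuda.amp.autocast():
        y_fp16 = attn_fp16(x)
        y_sb = attn_sb(x)
    return y_fp16.shape, y_sb.shape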
if __name__ == '__main__':
for dim in [1024, 1280, 1408, 1664, 2048]:
for batch in [2**14, 2**15, 2**16, 2**17]:
# if dim != 4096 or batch != 2**17:
# continue
x1 = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
qu = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
ke = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
standard = Attention(dim).cuda()
my_standard = Attention(dim, linear_module=StandardLinear).cuda()
sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
standard_compiled = torch.compile(standard)
ln_model = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
torch.nn.LayerNorm(dim),
).cuda()
ln_model_compiled = torch.compile(
ln_model
)
gelu_model = torch.nn.Sequential(
torch.nn.GELU(),
).cuda()
gelu_model_compiled = torch.compile(
gelu_model
)
print('Model part 2')
repeat = 32
info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim}
k = 'attn'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va)
((2 ** 16) * out_attn).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va)
((2 ** 16) * out_attn).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'ln'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = ln_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = ln_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'ln_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = ln_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = ln_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'gelu'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = gelu_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = gelu_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'gelu_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = gelu_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = gelu_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'my_standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
#
#
x1.grad.zero_()
k = 'standard_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'sb'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
info_json = json.dumps(info)
with open("tests/triton_tests/attn_info_ln.jsonl", "a") as file:
file.write(info_json + "\n")
#exit()
# err_fused = (out_standard - out_fused).abs().mean()
# err_sb = (out_standard - out_sb).abs().mean()
# print('OUT', err_fused, err_sb)
# err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean()
# err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean()
# print('GW2', err_fused, err_sb)
# err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean()
# err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean()
# print('GW1', err_fused, err_sb)
# err_fused = (x1.grad - x2.grad).abs().mean()
# err_sb = (x1.grad - x3.grad).abs().mean()
# print('GX1', err_fused, err_sb)
# import pdb; pdb.set_trace()
# # NO GELU, ST GRADIENTS, EVERYTHING FINE.
{"repeat": 32, "batch_size": 16384, "dim": 1024, "attn": 2.1414458751678467, "ln": 1.6365647315979004, "ln_compiled": 1.799367368221283, "gelu": 1.0930374264717102, "gelu_compiled": 1.094818115234375, "standard": 4.159651696681976, "my_standard": 4.696495831012726, "standard_compiled": 3.675594925880432, "sb": 4.1465312242507935}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "attn": 4.100345075130463, "ln": 3.1594187021255493, "ln_compiled": 3.437422215938568, "gelu": 2.109348773956299, "gelu_compiled": 2.11450457572937, "standard": 7.706902921199799, "my_standard": 8.799396455287933, "standard_compiled": 6.735652685165405, "sb": 7.66376405954361}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "attn": 7.953710854053497, "ln": 6.236426532268524, "ln_compiled": 6.746955215930939, "gelu": 4.164382815361023, "gelu_compiled": 4.171714186668396, "standard": 14.894917607307434, "my_standard": 17.042435705661774, "standard_compiled": 12.985721230506897, "sb": 14.6140456199646}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "attn": 15.638880431652069, "ln": 12.333884835243225, "ln_compiled": 13.272866606712341, "gelu": 8.228793740272522, "gelu_compiled": 8.243747055530548, "standard": 29.425136744976044, "my_standard": 35.08377820253372, "standard_compiled": 25.69487690925598, "sb": 28.760001063346863}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "attn": 2.627238631248474, "ln": 2.0098239183425903, "ln_compiled": 2.4197474122047424, "gelu": 1.3455823063850403, "gelu_compiled": 1.35069340467453, "standard": 5.554787814617157, "my_standard": 6.2290579080581665, "standard_compiled": 5.132324993610382, "sb": 5.4178386926651}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "attn": 5.0596073269844055, "ln": 3.903590142726898, "ln_compiled": 4.719957709312439, "gelu": 2.6203468441963196, "gelu_compiled": 2.627365291118622, "standard": 10.546617209911346, "my_standard": 11.850126087665558, "standard_compiled": 9.685918688774109, "sb": 10.088451206684113}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "attn": 9.845800697803497, "ln": 7.711298763751984, "ln_compiled": 9.292080998420715, "gelu": 5.172915756702423, "gelu_compiled": 5.180932581424713, "standard": 21.371990442276, "my_standard": 23.921720683574677, "standard_compiled": 19.669152796268463, "sb": 20.267993211746216}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "attn": 19.375711679458618, "ln": 15.333592891693115, "ln_compiled": 18.245264887809753, "gelu": 10.264746844768524, "gelu_compiled": 10.283775627613068, "standard": 41.79700464010239, "my_standard": 45.84744572639465, "standard_compiled": 38.35208714008331, "sb": 38.35364431142807}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "attn": 2.9110386967658997, "ln": 2.1998360753059387, "ln_compiled": 2.581551671028137, "gelu": 1.4731436967849731, "gelu_compiled": 1.478634774684906, "standard": 6.764143705368042, "my_standard": 7.331632077693939, "standard_compiled": 6.24605268239975, "sb": 6.325609982013702}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "attn": 5.542516708374023, "ln": 4.289716482162476, "ln_compiled": 5.065307021141052, "gelu": 2.8742849826812744, "gelu_compiled": 2.882353961467743, "standard": 12.749537825584412, "my_standard": 13.79828155040741, "standard_compiled": 11.728867888450623, "sb": 11.642806231975555}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "attn": 10.80312579870224, "ln": 8.471302688121796, "ln_compiled": 9.96796041727066, "gelu": 5.681410431861877, "gelu_compiled": 5.6905597448349, "standard": 25.19702911376953, "my_standard": 27.226239442825317, "standard_compiled": 23.22910726070404, "sb": 22.682294249534607}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "attn": 21.284908056259155, "ln": 16.85701310634613, "ln_compiled": 19.643358886241913, "gelu": 11.292420327663422, "gelu_compiled": 11.314474046230316, "standard": 50.06787180900574, "my_standard": 54.29378151893616, "standard_compiled": 44.58653926849365, "sb": 45.359253883361816}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "attn": 3.382459282875061, "ln": 2.6206374168395996, "ln_compiled": 2.9666870832443237, "gelu": 1.7263293266296387, "gelu_compiled": 1.7317384481430054, "standard": 8.414775133132935, "my_standard": 9.117811918258667, "standard_compiled": 7.7542513608932495, "sb": 7.70898163318634}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "attn": 6.468378007411957, "ln": 5.125559866428375, "ln_compiled": 5.791269242763519, "gelu": 3.3864825963974, "gelu_compiled": 3.3920034766197205, "standard": 16.016244888305664, "my_standard": 17.25083589553833, "standard_compiled": 14.60808515548706, "sb": 14.347739517688751}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "attn": 12.645229697227478, "ln": 10.13532280921936, "ln_compiled": 11.427387595176697, "gelu": 6.6957250237464905, "gelu_compiled": 6.711684167385101, "standard": 31.792201101779938, "my_standard": 34.31189805269241, "standard_compiled": 29.10037338733673, "sb": 28.3128023147583}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "attn": 24.970605969429016, "ln": 20.182937383651733, "ln_compiled": 22.7489173412323, "gelu": 13.326868414878845, "gelu_compiled": 13.345755636692047, "standard": 63.46555054187775, "my_standard": 70.19880414009094, "standard_compiled": 56.40875548124313, "sb": 56.22846633195877}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "attn": 4.080049693584442, "ln": 3.2655522227287292, "ln_compiled": 3.3329352736473083, "gelu": 2.108432352542877, "gelu_compiled": 2.114713191986084, "standard": 11.370822787284851, "my_standard": 12.234866619110107, "standard_compiled": 10.377615690231323, "sb": 10.209612548351288}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "attn": 7.74645060300827, "ln": 6.418220698833466, "ln_compiled": 6.55733048915863, "gelu": 4.163652658462524, "gelu_compiled": 4.171028733253479, "standard": 21.39316499233246, "my_standard": 23.04024249315262, "standard_compiled": 19.431106746196747, "sb": 18.732361495494843}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "attn": 15.235155820846558, "ln": 12.684382498264313, "ln_compiled": 12.895286083221436, "gelu": 8.228868246078491, "gelu_compiled": 8.242718875408173, "standard": 42.55136102437973, "my_standard": 45.82635313272476, "standard_compiled": 38.663335144519806, "sb": 36.76284849643707}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "attn": 30.24454414844513, "ln": 25.25731921195984, "ln_compiled": 25.67601203918457, "gelu": 16.384944319725037, "gelu_compiled": 16.409948468208313, "standard": 84.26841348409653, "my_standard": 91.10662341117859, "standard_compiled": 76.89539343118668, "sb": 71.73164188861847}
import json
import time
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias
# KNOWN ISSUE: "w_quantize_colwise_transpose" still needs optimizing when the embedding dim is very large.
# Not that big of an issue.
def get_time_standard_fwd(k, v):
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
##### time matmul 1
for _ in range(repeat // 2):
g.t().matmul(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
g.t().matmul(x)
torch.cuda.synchronize()
end = time.time()
print(f"time {k}: {(end - start) / repeat * 1000:.3f} ms")
return (end - start) / repeat * 1000
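# The per-kernel timing blocks below all repeat the same warm-up / synchronize /
# time pattern. A minimal helper in the spirit of the get_time(name, fn, info)
# calls used in the other benchmark script; a sketch, not part of this file
# originally. 'fn' is any zero-argument callable, and 'repeat' is the global
# set inside __main__ below.
def get_time(k, fn, info):
    for _ in range(repeat // 2):
        fn()  # warm-up iterations, not timed
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    end = time.time()
    ms = (end - start) / repeat * 1000
    print(f"time {k}: {ms:.3f} ms")
    info[k] = ms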
if __name__ == '__main__':
torch.manual_seed(0)
#for (dim, wm) in [(1024, 4), (1280, 4), (1408, 4.3637), (1664, 4.9231), (2048, 4), (4096, 4), (8096, 4)]
for (dim, wm) in [(1408, 4), (1664, 4),]:
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
#for batch_size in [256*256, 256*512]:
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
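# Placeholder int8 tensors and per-row / per-column / global "scale" states follow;
# the values are arbitrary (plain casts and max(), not real quantization) since this
# script only times the kernels and never checks numerical accuracy.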
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
k = 'standard_fwd'
for _ in range(repeat // 2):
x.matmul(w.t())
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
x.matmul(w.t())
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'standard_gw'
for _ in range(repeat // 2):
g.t().matmul(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
g.t().matmul(x)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'standard_gx'
for _ in range(repeat // 2):
g.matmul(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
g.matmul(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'rowwise_fwd'
for _ in range(repeat // 2):
int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'rowwise_bwd'
for _ in range(repeat // 2):
int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'global_fwd'
for _ in range(repeat // 2):
int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'global_bwd'
for _ in range(repeat // 2):
int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'x_quantize_rowwise'
for _ in range(repeat // 2):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'g_quantize_rowwise'
for _ in range(repeat // 2):
quantize_rowwise_nogroup(g)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(g)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_rowwise'
for _ in range(repeat // 2):
quantize_rowwise_nogroup(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_colwise_transpose'
for _ in range(repeat // 2):
quantize_columnwise_nogroup_transpose(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_columnwise_nogroup_transpose(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_global'
for _ in range(repeat // 2):
quantize_global(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_global(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_global_transpose'
for _ in range(repeat // 2):
quantize_global_transpose(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_global_transpose(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'cast_x'
for _ in range(repeat // 2):
newx = x.to(torch.int8)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
newx = x.to(torch.int8)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'cast_g'
for _ in range(repeat // 2):
newx = g.to(torch.int8)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
newx = g.to(torch.int8)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'cast_w'
for _ in range(repeat // 2):
newx = w.to(torch.int8)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
newx = w.to(torch.int8)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
with open("tests/triton_tests/info.jsonl", "a") as file:
file.write(info_json + "\n")
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 3.807276487350464, "my_standard": 4.196919500827789, "standard_compiled": 3.771558403968811, "sb": 3.5132691264152527}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 7.215872406959534, "my_standard": 7.991522550582886, "standard_compiled": 7.241688668727875, "sb": 6.581142544746399}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 14.26444947719574, "my_standard": 15.685759484767914, "standard_compiled": 14.251746237277985, "sb": 12.735314667224884}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 28.49559485912323, "my_standard": 31.26966953277588, "standard_compiled": 28.414390981197357, "sb": 25.319166481494904}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 5.887262523174286, "my_standard": 6.132654845714569, "standard_compiled": 5.902409553527832, "sb": 4.947789013385773}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 11.14131510257721, "my_standard": 12.859955430030823, "standard_compiled": 11.133037507534027, "sb": 9.303092956542969}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 22.193141281604767, "my_standard": 25.66336840391159, "standard_compiled": 22.22583442926407, "sb": 18.285617232322693}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 44.23898458480835, "my_standard": 51.30268633365631, "standard_compiled": 44.08355802297592, "sb": 35.999126732349396}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 6.938718259334564, "my_standard": 7.269218564033508, "standard_compiled": 6.94604218006134, "sb": 5.764961242675781}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 13.04878294467926, "my_standard": 13.742901384830475, "standard_compiled": 13.011425733566284, "sb": 10.774023830890656}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 26.738539338111877, "my_standard": 27.739346027374268, "standard_compiled": 26.75659954547882, "sb": 21.882005035877228}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 51.905401051044464, "my_standard": 53.98637801408768, "standard_compiled": 51.8316924571991, "sb": 41.67725890874863}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 9.233824908733368, "my_standard": 9.619377553462982, "standard_compiled": 9.214423596858978, "sb": 7.557623088359833}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 17.324909567832947, "my_standard": 17.996780574321747, "standard_compiled": 17.29544997215271, "sb": 14.035224914550781}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 35.51657497882843, "my_standard": 36.674730479717255, "standard_compiled": 35.43049842119217, "sb": 28.38330715894699}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 69.0087378025055, "my_standard": 71.56594842672348, "standard_compiled": 68.82885098457336, "sb": 54.01633679866791}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 12.590140104293823, "my_standard": 13.106442987918854, "standard_compiled": 12.606985867023468, "sb": 10.286301374435425}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 24.830535054206848, "my_standard": 25.563716888427734, "standard_compiled": 24.895809590816498, "sb": 19.559212028980255}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 49.55078661441803, "my_standard": 51.16480588912964, "standard_compiled": 49.739621579647064, "sb": 38.29141706228256}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 98.36294502019882, "my_standard": 102.69322991371155, "standard_compiled": 98.76712411642075, "sb": 75.88706165552139}
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 4.91420179605484, "my_standard": 5.577877163887024, "standard_compiled": 4.810944199562073, "sb": 4.512995481491089}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 8.876129984855652, "my_standard": 10.154612362384796, "standard_compiled": 8.820965886116028, "sb": 8.367843925952911}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 17.47015118598938, "my_standard": 19.857674837112427, "standard_compiled": 17.338842153549194, "sb": 15.992552042007446}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 34.824438393116, "my_standard": 39.499424397945404, "standard_compiled": 34.56207364797592, "sb": 31.573951244354248}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.342606782913208, "my_standard": 7.9323723912239075, "standard_compiled": 7.279552519321442, "sb": 6.395488977432251}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 13.69999349117279, "my_standard": 16.0503089427948, "standard_compiled": 13.603456318378448, "sb": 11.813104152679443}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 29.557034373283386, "my_standard": 34.2303067445755, "standard_compiled": 29.382556676864624, "sb": 22.882774472236633}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 53.629085421562195, "my_standard": 63.07622790336609, "standard_compiled": 53.33048850297928, "sb": 44.76426541805267}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 8.81417840719223, "my_standard": 9.477965533733368, "standard_compiled": 8.73943418264389, "sb": 7.479414343833923}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 16.242466866970062, "my_standard": 17.616644501686096, "standard_compiled": 16.14125818014145, "sb": 13.665586709976196}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 32.429613173007965, "my_standard": 34.80646014213562, "standard_compiled": 32.319076359272, "sb": 27.123987674713135}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 62.85770237445831, "my_standard": 67.55391508340836, "standard_compiled": 62.453076243400574, "sb": 51.53566598892212}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 11.585861444473267, "my_standard": 12.565858662128448, "standard_compiled": 11.504307389259338, "sb": 9.657211601734161}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 21.261662244796753, "my_standard": 22.771358489990234, "standard_compiled": 21.12410217523575, "sb": 17.64291524887085}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 42.85307973623276, "my_standard": 45.70870101451874, "standard_compiled": 42.57970303297043, "sb": 34.918561577796936}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 83.56057852506638, "my_standard": 89.11971747875214, "standard_compiled": 83.05662125349045, "sb": 66.32210314273834}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 15.7279372215271, "my_standard": 16.854502260684967, "standard_compiled": 15.655294060707092, "sb": 13.228952884674072}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 30.42648732662201, "my_standard": 32.26502239704132, "standard_compiled": 30.239209532737732, "sb": 24.354808032512665}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 60.779355466365814, "my_standard": 64.11923468112946, "standard_compiled": 60.89268624782562, "sb": 46.91776633262634}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 119.93677169084549, "my_standard": 128.19699943065643, "standard_compiled": 120.20225822925568, "sb": 92.3452153801918}