Commit a24aae30 authored by Jeongseok Kang's avatar Jeongseok Kang
Browse files

Merge branch 'main' into fix/libcuda-to-torch

parents 2b4cc256 4395d68c
from . import nn
from .autograd._functions import (
switchback_bnb,
matmul_fp8_global,
matmul_fp8_mixed,
)
import operator
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
import torch
import bitsandbytes.functional as F
from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler
# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
tensor = torch.Tensor
class MatMulFP8Mixed(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
B_shape = B.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. MatmulnN
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
cB, state = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, state).to(B.dtype)
output = torch.matmul(fp8A, fp8B)
# output is half
# 3. Save state
ctx.fw_code = fw_code
ctx.bw_code = bw_code
ctx.bsz = bsz
ctx.bsz2 = bsz2
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
if any(ctx.needs_input_grad[:2]):
# NOTE: we send back A, and re-quant.
ctx.tensors = (A, fp8B)
else:
ctx.tensors = (None, None)
return output
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B = None, None
# TODO: Fix blocksize to be output_dim
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
# not supported by PyTorch. TODO: create work-around
if req_gradA:
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
if req_gradB:
if len(A.shape) == 3:
At = A.transpose(2, 1).contiguous()
else:
At = A.transpose(1, 0).contiguous()
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
# fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
return grad_A, grad_B, None, None, None, None, None
class MatMulFP8Global(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
B_shape = B.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. MatmulnN
cA, state = F.quantize(A.float(), code=fw_code)
fp8A = F.dequantize(cA, state).to(A.dtype)
cB, state = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, state).to(B.dtype)
output = torch.matmul(fp8A, fp8B)
# output is half
# 3. Save state
ctx.fw_code = fw_code
ctx.bw_code = bw_code
ctx.bsz = bsz
ctx.bsz2 = bsz2
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
if any(ctx.needs_input_grad[:2]):
# NOTE: we send back A, and re-quant.
ctx.tensors = (A, fp8B)
else:
ctx.tensors = (None, None)
return output
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B = None, None
# TODO: Fix blocksize to be output_dim
cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
# not supported by PyTorch. TODO: create work-around
if req_gradA:
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
if req_gradB:
if len(A.shape) == 3:
At = A.transpose(2, 1).contiguous()
else:
At = A.transpose(1, 0).contiguous()
cA, state = F.quantize(At.float(), code=ctx.fw_code)
fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
return grad_A, grad_B, None, None, None, None, None
class SwitchBackBnb(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
# Cast A to fp16
if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
A.to(torch.float16), threshold=state.threshold
)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
idx = torch.unique(coo_tensorA.colidx).long()
CA[:, idx] = 0
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
#print('A shape', A.shape)
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
# 2. Quantize B
if state.has_fp16_weights:
#print('B shape', B.shape)
has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
(
CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
if coo_tensorA is not None and not state.has_fp16_weights:
# extract outliers
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.to(A.dtype)
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
shapeB = state.SB[0]
if len(input_shape) == 3:
output_shape = (input_shape[0], input_shape[1], shapeB[0])
else:
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
# we apply the fused bias here
if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)
# 5. Save state
ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None, None]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
return clone_func(output.view(output_shape))
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
# print('back A shape', A.shape)
# print('grad output t shape', grad_output.t().shape)
grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
# print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception('State must contain either CBt or CB matrix for backward')
return grad_A, grad_B, None, grad_bias, None
def get_block_sizes(input_matrix, weight_matrix):
input_features = input_matrix.shape[-1]
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
bsz, bsz2 = 1024, 1024
for i, k in enumerate(array):
if input_features > array[i + 1]:
bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
bsz2 = k
break
return bsz, bsz2
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
def switchback_bnb(
A: tensor,
B: tensor,
out: tensor = None,
state: MatmulLtState = None,
threshold=0.0,
bias=None
):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return SwitchBackBnb.apply(A, B, out, bias, state)
from .modules import LinearFP8Mixed, LinearFP8Global
from typing import Optional, TypeVar, Union, overload
import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
T = TypeVar("T", bound="torch.nn.Module")
class LinearFP8Mixed(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break
def forward(self, x: torch.Tensor):
if self.fw_code is None:
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias
return out
class LinearFP8Global(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
for i, k in enumerate(array):
if output_features > array[i + 1]:
self.bsz2 = k
break
def forward(self, x: torch.Tensor):
if self.fw_code is None:
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
if self.bias is not None:
out += self.bias
return out
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# rowwise quantize
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _dequantize_rowwise(
x_ptr,
state_x,
output_ptr,
inv_127,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
max_val = tl.load(state_x + pid)
output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
w_factor = tl.load(state_w_ptr)
x_factor = tl.load(state_x_ptr + ram)[:, None]
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
# conditionally add bias
if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_mixed_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and columnwise quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
w_factor = tl.load(state_w_ptr + rbn)[None, :]
x_factor = tl.load(state_x_ptr + ram)[:, None]
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch int8_matmul_rowwise_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This kernel does fused columnwise quantization and transpose.
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_stages=16),
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=16, num_warps=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_columnwise_and_transpose(
x_ptr,
output_ptr,
output_maxs,
n_elements,
M : tl.constexpr, N : tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid
p2_arange = tl.arange(0, P2)
p2_arange_mask = p2_arange < M
arange = p2_arange * N
offsets = block_start + arange
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
new_start = pid * M
new_offsets = new_start + p2_arange
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
tl.store(output_maxs + pid, max_val)
def quantize_columnwise_and_transpose(x: torch.Tensor):
M, N = x.shape
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(M))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
return output, output_maxs
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_global_transpose(input): return None
def quantize_global(x: torch.Tensor): return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# global quantize
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
],
key=['n_elements']
)
@triton.jit
def _quantize_global(
x_ptr,
absmax_inv_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
output = tl.libdevice.llrint(127. * (x * absmax_inv))
tl.store(output_ptr + offsets, output, mask=mask)
def quantize_global(x: torch.Tensor):
absmax = x.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_quantize_global[grid](x, absmax_inv, output, n_elements)
return output, absmax
# global quantize and transpose
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
# ...
],
key=['M', 'N']
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
BLOCK_M : tl.constexpr,
BLOCK_N : tl.constexpr,
GROUP_M : tl.constexpr):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
mask = (rm < M)[:, None] & (rn < N)[None, :]
a = tl.load(A, mask=mask)
absmax_inv = tl.load(absmax_inv_ptr)
# rematerialize to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
output = tl.libdevice.llrint(127. * (a * absmax_inv))
tl.store(B, output, mask=mask)
def quantize_global_transpose(input):
absmax = input.abs().max().unsqueeze(0)
absmax_inv = 1./ absmax
M, N = input.shape
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
assert out.size(0) == N and out.size(1) == M
assert input.stride(0) == 1 or input.stride(1) == 1
assert out.stride(0) == 1 or out.stride(1) == 1
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
return out, absmax
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def quantize_rowwise(x: torch.Tensor): return None
else:
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# rowwise quantize
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise(
x_ptr,
output_ptr,
output_maxs,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise(x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs
import importlib
def is_triton_available():
return importlib.util.find_spec("triton") is not None
import shlex import shlex
import subprocess import subprocess
import torch
from typing import Tuple from typing import Tuple
def outlier_hook(module, input):
assert isinstance(module, torch.nn.Linear)
tracer = OutlierTracer.get_instance()
hvalue = tracer.get_hvalue(module.weight)
if hvalue not in tracer.hvalue2outlier_idx:
outlier_idx = find_outlier_dims(module.weight)
tracer.outliers.append(outlier_idx)
tracer.hvalues.append(hvalue)
if len(tracer.outliers) > 1:
# assign the current layer the outlier idx found from the weight
# of the previous linear layer
if tracer.outliers[-1].numel() > 0:
assert tracer.outliers[-1].max() < module.weight.shape[1]
tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
else:
# first layer, we cannot use the weight for outlier detection
# we follow a mixed approach:
# (1) zscore test of std of hidden dimension
# (2) magnitude > 6 test
merged = input[0].view(-1, input[0].shape[-1])
# (1) zscore test of std of hidden dimension
outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
# (2) magnitude > 6 test
dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1)))
outlier_idx2 = torch.where(dims > 0)[0]
outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
tracer.hvalue2outlier_idx[hvalue] = outlier_idx
else:
for hook in tracer.hooks:
hook.remove()
class OutlierTracer(object):
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self, model):
self.last_w = None
self.current_outlier_dims = None
self.hvalues = []
self.outliers = []
self.hvalue2outlier_idx = {}
self.initialized = True
self.hooks = []
for n, m in model.named_modules():
if isinstance(m, torch.nn.Linear):
self.hooks.append(m.register_forward_pre_hook(outlier_hook))
def is_initialized(self):
return getattr(self, 'initialized', False)
def get_hvalue(self, weight):
return weight.data.storage().data_ptr()
def get_outliers(self, weight):
if not self.is_initialized():
print('Outlier tracer is not initialized...')
return None
hvalue = self.get_hvalue(weight)
if hvalue in self.hvalue2outlier_idx:
return self.hvalue2outlier_idx[hvalue]
else:
return None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
return cls._instance
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
if rdm:
return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
m = weight.mean(reduction_dim)
mm = m.mean()
mstd = m.std()
zm = (m-mm)/mstd
std = weight.std(reduction_dim)
stdm = std.mean()
stdstd = std.std()
zstd = (std-stdm)/stdstd
if topk is not None:
val, idx = torch.topk(std.abs(), k=topk, dim=0)
else:
idx = torch.where(zstd > zscore)[0]
return idx
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
copy_weights (`bool`):
Copy the weights from the old linear module to the new one
post_processing_fun_name (`str`):
A function name of the replacement linear class that is called
after processing.
"""
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
if copy_weights:
model._modules[name].weight = old_module.weight
model._modules[name].bias = old_module.bias
if post_processing_function is not None:
func = getattr(module, post_processing_function, None)
if func is not None: func(module)
return model
def execute_and_return(command_string: str) -> Tuple[str, str]: def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple): def _decode(subprocess_err_out_tuple):
...@@ -21,3 +157,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]: ...@@ -21,3 +157,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
std_out, std_err = execute_and_return_decoded_std_streams(command_string) std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err return std_out, std_err
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
copy_weights (`bool`):
Copy the weights from the old linear module to the new one
post_processing_fun_name (`str`):
A function name of the replacement linear class that is called
after processing.
"""
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
if copy_weights:
model._modules[name].weight = old_module.weight
model._modules[name].bias = old_module.bias
if post_processing_function is not None:
func = getattr(module, post_processing_function, None)
if func is not None: func(module)
return model
...@@ -33,3 +33,8 @@ You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be abl ...@@ -33,3 +33,8 @@ You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be abl
If you have problems compiling the library with these instructions from source, please open an issue. If you have problems compiling the library with these instructions from source, please open an issue.
## Compilation with Kepler
Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler`
This diff is collapsed.
...@@ -9,13 +9,15 @@ ...@@ -9,13 +9,15 @@
#ifndef kernels #ifndef kernels
#define kernels #define kernels
//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
...@@ -120,4 +122,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T ...@@ -120,4 +122,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
#endif #endif
...@@ -50,54 +50,53 @@ void dequantize(float *code, unsigned char *A, float *out, int n) ...@@ -50,54 +50,53 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{ {
int num_blocks = n/blocksize; int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(STOCHASTIC == 1)
assert(blocksize == 4096);
if(blocksize == 4096) if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048) else if(blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024) else if(blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512) else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256) else if(blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128) else if(blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64) else if(blocksize == 64)
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{ {
int num_blocks = n/blocksize; int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(blocksize == 4096) int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048) if(DATA_TYPE > 0)
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n); kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
else if(blocksize == 1024) else
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n); kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
else if(blocksize == 512)
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
else if(blocksize == 256)
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
else if(blocksize == 128)
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
else if(blocksize == 64)
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
//{
// int num_blocks = (colsB+32-1)/32;
// kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
// CUDA_CHECK_RETURN(cudaPeekAtLastError());
//}
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
...@@ -683,10 +682,73 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id ...@@ -683,10 +682,73 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
int num_blocks = (m+31)/32;
//cout << num_blocks << endl;
//cout << lda << endl;
//cout << ldb << endl;
//cout << ldc << endl;
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
//if(bits == 32)
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
}
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
int num_blocks = (m+31)/32;
//cout << num_blocks << endl;
//cout << lda << endl;
//cout << ldb << endl;
//cout << ldc << endl;
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
{
int threads = 512;
int blocks = n/threads;
blocks = n % threads == 0 ? blocks : blocks + 1;
blocks = blocks > 65535 ? 65535 : blocks;
kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
//============================================================== //==============================================================
// TEMPLATE DEFINITIONS // TEMPLATE DEFINITIONS
//============================================================== //==============================================================
template void func<float, FILL>(float *A, float *B, float value, long n);
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
...@@ -710,12 +772,20 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, ...@@ -710,12 +772,20 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
#define MAKE_optimizer32bit(name, gtype) \ #define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
...@@ -725,12 +795,14 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \ ...@@ -725,12 +795,14 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, float)
...@@ -766,8 +838,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); ...@@ -766,8 +838,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION); MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
...@@ -20,6 +20,11 @@ ...@@ -20,6 +20,11 @@
#include <vector> #include <vector>
#include <functional> #include <functional>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#define CUDA_CHECK_RETURN(value) { \ #define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \ cudaError_t _m_cudaStat = value; \
if (_m_cudaStat != cudaSuccess) { \ if (_m_cudaStat != cudaSuccess) { \
...@@ -82,6 +87,20 @@ typedef enum Transform_t ...@@ -82,6 +87,20 @@ typedef enum Transform_t
COL_AMPERE = 4, COL_AMPERE = 4,
} Transform_t; } Transform_t;
typedef enum DataType_t
{
General8bit = 0,
FP4 = 1,
NF4 = 2,
} DataType_t;
typedef enum Funcs_t
{
FILL = 0,
ARANGE = 1,
_MUL = 2,
} Funcs_t;
class Context class Context
{ {
public: public:
...@@ -129,8 +148,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in ...@@ -129,8 +148,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
void quantize(float *code, float *A, unsigned char *out, int n); void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n);
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
...@@ -177,4 +196,11 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, ...@@ -177,4 +196,11 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
#endif #endif
...@@ -20,8 +20,25 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat ...@@ -20,8 +20,25 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
#define MAKE_FUNC32(fname, oname, gtype, gbits) \ #define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_g##gbits(gtype *g, gtype *p, \ void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \ const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
...@@ -29,17 +46,19 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \ ...@@ -29,17 +46,19 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16) MAKE_FUNC32(momentum, MOMENTUM, half, 16)
MAKE_FUNC32(adam, ADAM, float, 32) MAKE_FUNC32(adam, ADAM, float, fp32)
MAKE_FUNC32(adam, ADAM, half, 16) MAKE_FUNC32(adam, ADAM, half, fp16)
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(lion, LION, float, 32) MAKE_FUNC32(lion, LION, float, fp32)
MAKE_FUNC32(lion, LION, half, 16) MAKE_FUNC32(lion, LION, half, fp16)
MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16) MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
#define MAKE_FUNC8(fname, oname, gtype, gbits) \ #define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \ float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \ float beta1, float beta2, \
float eps, int step, float lr, \ float eps, int step, float lr, \
...@@ -61,33 +80,42 @@ MAKE_FUNC8(lion, LION, float, 32) ...@@ -61,33 +80,42 @@ MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16) MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ { optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
MAKE_BLOCKWISE8(adam, ADAM, half, 16) MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, float, 32) MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_BLOCKWISE8(lion, LION, half, 16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_BLOCKWISE8(lion, LION, float, 32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) MAKE_BLOCKWISE8(lion, LION, half, fp16)
MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); } void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); } void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); } void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
...@@ -148,32 +176,41 @@ extern "C" ...@@ -148,32 +176,41 @@ extern "C"
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
#define MAKE_CFUNC32(name, gtype, gbits) \ #define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_g##gbits(gtype *g, gtype *p, \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \ const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, 32) MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, 16) MAKE_CFUNC32(adam, half, fp16)
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, 32) MAKE_CFUNC32(lion, float, fp32)
MAKE_CFUNC32(lion, half, 16) MAKE_CFUNC32(lion, half, fp16)
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16) MAKE_CFUNC32(adagrad, half, 16)
#define MAKE_CFUNC8(name, gtype, gbits) \ #define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \ float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \ float beta1, float beta2, \
float eps, int step, float lr, \ float eps, int step, float lr, \
...@@ -181,7 +218,7 @@ extern "C" ...@@ -181,7 +218,7 @@ extern "C"
float* max1, float* max2, float* new_max1, float* new_max2, \ float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, float gnorm_scale, int n) \ float weight_decay, float gnorm_scale, int n) \
{ \ { \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
} \ } \
...@@ -195,22 +232,23 @@ extern "C" ...@@ -195,22 +232,23 @@ extern "C"
MAKE_CFUNC8(lion, half, 16) MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
MAKE_CBLOCKWISE8(adam, ADAM, half, 16) MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, 32) MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, half, 16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, 32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
...@@ -298,6 +336,38 @@ extern "C" ...@@ -298,6 +336,38 @@ extern "C"
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void *cget_managed_ptr(size_t bytes)
{
void *ptr;
CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
return ptr;
}
void cprefetch(void *ptr, size_t bytes, int device)
{
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \
CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
#endif #endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment