Commit 5a4263f4 authored by Ruff, committed by Aarni Koskela

Reformat with ruff-format

parent 02e30ca6
......@@ -7,9 +7,7 @@ def get_platform_tag(architecture):
system = platform.system()
if system == "Linux":
tag = (
"manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
)
tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
elif system == "Darwin":
tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
elif system == "Windows":
......
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
cmap = plt.get_cmap("cool")
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
if __name__ == "__main__":
fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
gs = gridspec.GridSpec(1, 2)
dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
......@@ -19,25 +17,28 @@ if __name__ == '__main__':
ax = fig.add_subplot(gs[0, 0])
# TODO: change this to what you want.
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
df = rdf[rdf.batch_size == batch_size_for_plot1]
# first plot the time occupied by different operations
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (sum of parts)",
),
("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
]:
xs = []
ys = []
......@@ -47,40 +48,46 @@ if __name__ == '__main__':
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
ax.plot(
xs,
ys,
color=color,
label=name,
marker=marker,
markersize=5 if marker == "s" else 5,
linestyle=ls,
linewidth=2 if "+" in k else 1.0,
)
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13)
ax.set_xlabel("dim", fontsize=13)
ax.set_ylabel("time (ms)", fontsize=13)
ax.grid()
ax.set_xscale('log')
ax.set_xscale("log")
if logscale_plot1:
ax.set_yscale('log')
ax.set_yscale("log")
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight('bold')
leg.get_texts()[1].set_fontweight('bold')
leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight("bold")
leg.get_texts()[1].set_fontweight("bold")
plt.subplots_adjust(left=0.1)
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)
ax = fig.add_subplot(gs[0, 1])
......@@ -88,10 +95,15 @@ if __name__ == '__main__':
for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (total time)",
),
]:
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in dims_to_consider:
......@@ -99,11 +111,11 @@ if __name__ == '__main__':
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
......@@ -111,25 +123,29 @@ if __name__ == '__main__':
color = cmap(j * 0.25)
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
markers = ["^", "v", "P", "o"]
ax.plot(
all_xs[0],
real_ys,
color=color,
label=f"batch * sequence length = {batch_size}",
marker=markers[j],
markersize=5 if marker == "s" else 5,
)
ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.set_xlabel("dim", fontsize=13)
ax.set_xscale("log")
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=13)
ax.set_ylabel(r"% speedup", fontsize=13)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)
plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
......@@ -20,8 +20,8 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embed dim is too large.
def get_time(k, fn, info_dict):
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
......@@ -36,16 +36,15 @@ def get_time(k, fn, info_dict):
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
if __name__ == '__main__':
if __name__ == "__main__":
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
......@@ -73,35 +72,86 @@ if __name__ == '__main__':
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info = {
"repeat": repeat,
"batch_size": batch_size,
"dim_out": dim_out,
"dim_in": dim_in,
"wm": wm,
"switch": switch,
}
get_time("standard_fwd", lambda: x.matmul(w.t()), info)
get_time("standard_gw", lambda: g.t().matmul(x), info)
get_time("standard_gx", lambda: g.matmul(w), info)
get_time(
"rowwise_fwd",
lambda: int8_matmul_rowwise_dequantize(
x_int8,
w_int8.t(),
state_x_rowwise,
state_w_columnwise,
None,
),
info,
)
get_time(
"rowwise_bwd",
lambda: int8_matmul_rowwise_dequantize(
g_int8,
wt_int8.t(),
state_x_rowwise,
state_w_rowwise,
None,
),
info,
)
get_time(
"global_fwd",
lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time(
"global_bwd",
lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
get_time("w_quantize_global", lambda: quantize_global(w), info)
get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)
time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
time_rowwise = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_colwise_transpose"]
+ info["w_quantize_rowwise"]
+ info["standard_gw"]
+ info["rowwise_fwd"]
+ info["rowwise_bwd"]
)
time_global = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_global"]
+ info["w_quantize_global_transpose"]
+ info["standard_gw"]
+ info["global_fwd"]
+ info["global_bwd"]
)
print("TOTAL STANDARD", time_standard)
print("TOTAL ROWWISE", time_rowwise)
print("TOTAL GLOBAL", time_global)
print("speedup", -100 * (time_global - time_standard) / time_standard)
info["time_standard"] = time_standard
info["time_rowwise"] = time_rowwise
info["time_global"] = time_global
info_json = json.dumps(info)
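To make the speedup print above concrete, here is a small worked check of the same formula with made-up timings (illustration only, not part of this commit):

time_standard = 1.20  # ms, hypothetical sum of standard_fwd + standard_gx + standard_gw
time_global = 0.90    # ms, hypothetical sum of the SwitchBack int8 parts
speedup = -100 * (time_global - time_standard) / time_standard
print(speedup)  # 25.0 -> the int8 path is 25% faster than the fp16 baseline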
......
......@@ -14,16 +14,18 @@ import bitsandbytes.functional as F
def prod(iterable):
return reduce(operator.mul, iterable, 1)
# The inverse transformation for the colTuring and colAmpere formats was contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler:
_instance = None
......@@ -83,6 +85,7 @@ def get_inverse_transform_indices(
break # if all indices fit in i bytes, stop early
return permuted_tile_indices
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
......@@ -159,20 +162,12 @@ class MatMul8bit(torch.autograd.Function):
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(
A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B, S2.t(), S1, grad_output.dtype, quant_type
)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
else:
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(
A, dim=dims, quant_type=quant_type
)
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
......@@ -201,9 +196,7 @@ class MatMul8bit(torch.autograd.Function):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
......@@ -227,7 +220,7 @@ def supports_igemmlt(device: torch.device) -> bool:
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
......@@ -246,6 +239,7 @@ def get_tile_inds(format, device):
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
......@@ -510,7 +504,6 @@ class MatMul4Bit(torch.autograd.Function):
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
......@@ -532,7 +525,7 @@ class MatMul4Bit(torch.autograd.Function):
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B, grad_bias = None, None, None
......@@ -542,8 +535,9 @@ class MatMul4Bit(torch.autograd.Function):
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
return grad_A, grad_B, None, grad_bias, None
......@@ -554,7 +548,7 @@ def matmul(
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None
bias=None,
):
state = state or MatmulLtState()
if threshold > 0.0:
......@@ -562,11 +556,19 @@ def matmul(
return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
def matmul_4bit(
A: torch.Tensor,
B: torch.Tensor,
quant_state: F.QuantState,
out: Optional[torch.Tensor] = None,
bias=None,
):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(f'Some matrices\' hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
warn(
f"Some matrices' hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
)
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
......
......@@ -56,7 +56,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
"If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
"If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
"For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n"
"For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
)
return PACKAGE_DIR / library_name
......@@ -100,7 +100,7 @@ def get_native_library() -> BNBNativeLibrary:
logger.warning(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
)
return BNBNativeLibrary(dll)
......@@ -120,5 +120,5 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
"""
""",
)
......@@ -120,7 +120,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
The CUDA version for the compile might depend on your conda install, if using conda.
Inspect CUDA version via `conda list | grep cuda`.
"""
""",
)
cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
......@@ -129,7 +129,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
"""
WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
You will only be able to use 8-bit optimizers and quantization routines!
"""
""",
)
print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
......@@ -170,7 +170,7 @@ def print_cuda_runtime_diagnostics() -> None:
In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
"""
""",
)
for pth in cudart_paths:
print(f"* Found CUDA runtime at: {pth}")
......@@ -25,7 +25,7 @@ def sanity_check():
See the documentation for more details if needed.
Trying a simple check anyway, but this will likely fail...
"""
""",
)
from bitsandbytes.optim import Adam
......@@ -71,7 +71,7 @@ def main():
print(
f"WARNING: {__package__} is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!"
f"If you think that this is so erroneously,\nplease report an issue!",
)
except Exception:
traceback.print_exc()
......@@ -80,6 +80,6 @@ def main():
Above we output some debug information.
Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
WARNING: Please be sure to sanitize sensitive info from the output before posting it.
"""
""",
)
sys.exit(1)
......@@ -21,6 +21,7 @@ from .cextension import lib
def prod(iterable):
return reduce(operator.mul, iterable, 1)
name2qmap = {}
if lib and lib.compiled_with_cuda:
......@@ -127,7 +128,6 @@ class GlobalPageManager:
prefetch_tensor(t, to_cpu)
class CUBLAS_Context:
_instance = None
......@@ -169,6 +169,7 @@ class Cusparse_Context:
cls._instance.initialize()
return cls._instance
dtype2bytes = {}
dtype2bytes[torch.float32] = 4
dtype2bytes[torch.float16] = 2
......@@ -176,10 +177,11 @@ dtype2bytes[torch.bfloat16] = 2
dtype2bytes[torch.uint8] = 1
dtype2bytes[torch.int8] = 1
FIRST_CUDA_DEVICE = torch.device('cuda', index=0)
FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
num_bytes = dtype2bytes[dtype]*prod(shape)
num_bytes = dtype2bytes[dtype] * prod(shape)
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
......@@ -188,31 +190,35 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
out.page_deviceid = device.index
return out
def prefetch_tensor(A, to_cpu=False):
assert A.is_paged, 'Only paged tensors can be prefetched!'
assert A.is_paged, "Only paged tensors can be prefetched!"
if to_cpu:
deviceid = -1
else:
deviceid = A.page_deviceid
num_bytes = dtype2bytes[A.dtype]*A.numel()
num_bytes = dtype2bytes[A.dtype] * A.numel()
lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
def elementwise_func(func_name, A, B, value, prefetch=True):
func = None
if A.dtype == torch.float32:
func = getattr(lib, f'c{func_name}_fp32', None)
func = getattr(lib, f"c{func_name}_fp32", None)
cvalue = ct.c_float(value)
elif A.dtype == torch.uint8:
func = getattr(lib, f'c{func_name}_uint8', None)
func = getattr(lib, f"c{func_name}_uint8", None)
cvalue = ct.c_uint8(value)
if func is None: raise NotImplementedError(f'Function not implemented: {func_name}')
if func is None:
raise NotImplementedError(f"Function not implemented: {func_name}")
is_managed = getattr(A, 'is_managed', False)
is_managed = getattr(A, "is_managed", False)
if is_managed and prefetch:
prefetch_tensor(A)
if B is not None: prefetch_tensor(B)
if B is not None:
prefetch_tensor(B)
func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
if A.is_paged or B.is_paged:
......@@ -222,28 +228,36 @@ def elementwise_func(func_name, A, B, value, prefetch=True):
# operation occurred. So we synchronize.
torch.cuda.synchronize()
def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value)
def arange(A, device=None): elementwise_func('arange', A, None, 0)
def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0)
def fill(A, value, device=None, prefetch=True):
elementwise_func("fill", A, None, value)
def arange(A, device=None):
elementwise_func("arange", A, None, 0)
def _mul(A, B, device=None):
elementwise_func("_mul", A, B, 0)
def create_linear_map(signed=True, total_bits=8, add_zero=True):
sign = (-1.0 if signed else 0.0)
sign = -1.0 if signed else 0.0
total_values = 2**total_bits
if add_zero or total_bits < 8:
# add a zero
# since we simulate fewer bits by having zeros in the data type,
# we need to center the quantization around zero and as such lose
# a single value
total_values = (2**total_bits if not signed else 2**total_bits-1)
total_values = 2**total_bits if not signed else 2**total_bits - 1
values = torch.linspace(sign, 1.0, total_values)
gap = 256 - values.numel()
if gap == 0:
return values
else:
l = values.numel()//2 # noqa: E741
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
l = values.numel() // 2 # noqa: E741
return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist())
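A quick sketch of what the default call returns (follows directly from the code above; illustration only):

from bitsandbytes.functional import create_linear_map

code = create_linear_map(signed=True, total_bits=8)
assert code.numel() == 256   # 255 evenly spaced values in [-1, 1] padded with an extra zero to 256
assert (code == 0).any()     # the zero inserted in the middle when a gap is left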
def create_normal_map(offset=0.9677083, use_extra_value=True):
......@@ -251,18 +265,17 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
from scipy.stats import norm
except ImportError as ie:
raise ImportError(
"Scipy is required for `create_normal_map`. "
"Install `bitsandbytes` with the `[test]` extra."
"Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.",
) from ie
if use_extra_value:
# one more positive value, this is an asymmetric type
v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
else:
v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
v = v1 + v2 + v3
......@@ -275,38 +288,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
return values
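A hedged sketch of the resulting NF4 code table (requires scipy, per the import guard above; illustration only):

from bitsandbytes.functional import create_normal_map

code = create_normal_map()   # defaults: offset=0.9677083, use_extra_value=True
assert code.numel() == 256   # 15 non-zero NF4 values plus zero padding, as noted in the comments above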
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits
p = precision_bits
has_sign = 1 if signed else 0
assert e+p == total_bits-has_sign
assert e + p == total_bits - has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)):
for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
evalues.append(2**val)
values = []
lst = list(itertools.product([0, 1], repeat=precision_bits))
#for ev in evalues:
bias = 2**(exponent_bits-1)
for evalue in range(2**(exponent_bits)):
# for ev in evalues:
bias = 2 ** (exponent_bits - 1)
for evalue in range(2 ** (exponent_bits)):
for bit_pattern in lst:
value = (1 if evalue != 0 else 0)
value = 1 if evalue != 0 else 0
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
value += pval * (2 ** -(i + 1))
if evalue == 0:
# subnormals
value = value*2**-(bias)
value = value * 2**-(bias)
else:
# normals
value = value*2**-(evalue-bias-1)
value = value * 2 ** -(evalue - bias - 1)
values.append(value)
if signed:
values.append(-value)
assert len(values) == 2**total_bits
values.sort()
if total_bits < 8:
......@@ -320,7 +332,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
return code
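For illustration, a hypothetical call building an E4M3-style signed 8-bit map (the parameter values here are just an example, not taken from this commit):

from bitsandbytes.functional import create_fp8_map

code = create_fp8_map(signed=True, exponent_bits=4, precision_bits=3, total_bits=8)
# exponent_bits + precision_bits == total_bits - 1 sign bit, as asserted above
assert code.numel() == 256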
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
"""
Creates the dynamic quantization map.
......@@ -345,7 +356,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
non_sign_bits = total_bits - (1 if signed else 1)
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
for i in range(max_exponent_bits):
fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)
fraction_items = int(
2 ** (i + non_sign_bits - max_exponent_bits) + 1
if signed
else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
)
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
......@@ -371,8 +386,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.sort()
return Tensor(data)
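A small usage sketch showing how this map feeds quantize_blockwise below, mirroring the library's own name2qmap["dynamic"] default (assumes a CUDA device; illustration only):

import torch
from bitsandbytes.functional import create_dynamic_map, quantize_blockwise

code = create_dynamic_map(signed=True)          # 8-bit dynamic quantization code
A = torch.randn(4096, device="cuda", dtype=torch.float16)
q, state = quantize_blockwise(A, code=code)     # uint8 codes plus per-block absmax in the returned state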
def create_quantile_map(A, total_bits=8):
q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
q = q.tolist()
q.append(0)
......@@ -383,11 +399,13 @@ def create_quantile_map(A, total_bits=8):
q.sort()
q = Tensor(q)
q = q/q.abs().max()
q = q / q.abs().max()
return q
def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
if not torch.cuda.is_available():
return "col_turing"
major, _minor = torch.cuda.get_device_capability()
if major <= 7:
return "col_turing"
......@@ -396,20 +414,24 @@ def get_special_format_str():
return "col_turing"
def is_on_gpu(tensors):
on_gpu = True
gpu_ids = set()
for t in tensors:
if t is None: continue # NULL pointers are fine
is_paged = getattr(t, 'is_paged', False)
on_gpu &= (t.device.type == 'cuda' or is_paged)
if t is None:
continue # NULL pointers are fine
is_paged = getattr(t, "is_paged", False)
on_gpu &= t.device.type == "cuda" or is_paged
if not is_paged:
gpu_ids.add(t.device.index)
if not on_gpu:
raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
raise TypeError(
f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}",
)
if len(gpu_ids) > 1:
raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
raise TypeError(
f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}",
)
return on_gpu
......@@ -447,15 +469,13 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
if not hasattr(lib, name):
print(name)
raise ValueError(
f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
)
else:
return getattr(lib, name)
def get_transform_buffer(
shape, dtype, device, to_order, from_order="row", transpose=False
):
def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
# init_func = torch.empty
init_func = torch.zeros
dims = len(shape)
......@@ -508,9 +528,7 @@ def nvidia_transform(
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(
state[0], A.dtype, A.device, to_order, state[1]
)
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
else:
new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)
......@@ -534,8 +552,13 @@ def nvidia_transform(
return out, new_state
def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
'''
def estimate_quantiles(
A: Tensor,
out: Optional[torch.Tensor] = None,
offset: float = 1 / 512,
num_quantiles=256,
) -> Tensor:
"""
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
......@@ -562,14 +585,21 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
if num_quantiles < 256 and offset == 1/(512):
"""
if A.numel() < 256:
raise NotImplementedError(
f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.",
)
if num_quantiles > 256:
raise NotImplementedError(
f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}",
)
if num_quantiles < 256 and offset == 1 / (512):
# override default arguments
offset = 1/(2*num_quantiles)
offset = 1 / (2 * num_quantiles)
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
if out is None:
out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
device = pre_call(A.device)
if A.dtype == torch.float32:
......@@ -581,7 +611,7 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
post_call(device)
if num_quantiles < 256:
step = round(256/num_quantiles)
step = round(256 / num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]
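A minimal usage sketch for estimate_quantiles (assumes a CUDA tensor with at least 256 elements, per the checks above; illustration only):

import torch
from bitsandbytes.functional import estimate_quantiles

A = torch.randn(1 << 20, device="cuda")
q = estimate_quantiles(A)                        # 256 approximately equidistant quantiles of A's eCDF
assert q.numel() == 256 and q.dtype == torch.float32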
......@@ -590,12 +620,35 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
class QuantState:
"""container for quantization state components to work with Params4bit and similar classes"""
valid_quant_types = ('fp4', 'nf4')
valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
valid_quant_types = ("fp4", "nf4")
valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
valid_qs_keys = [
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"quant_state",
"quant_type",
"blocksize",
"dtype",
"shape",
"nested_blocksize",
"nested_dtype",
"nested_offset",
]
def __init__(
self,
absmax,
shape=None,
code=None,
blocksize=None,
quant_type=None,
dtype=None,
offset=None,
state2=None,
):
self.absmax = absmax
self.shape = shape
self.code = code
......@@ -614,13 +667,20 @@ class QuantState:
state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
"""
if self.nested:
list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type]
list_repr = [
self.absmax,
self.shape,
self.dtype,
self.blocksize,
[self.offset, self.state2],
self.quant_type,
]
else:
list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
return list_repr[idx]
@classmethod
def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState":
"""
unpacks components of state_dict into QuantState
where necessary, convert into strings, torch.dtype, ints, etc.
......@@ -632,37 +692,39 @@ class QuantState:
# unpacking tensor with non-tensor components
qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
if not len(qs_key) and 'quant_type' not in qs_dict:
if not len(qs_key) and "quant_type" not in qs_dict:
raise ValueError("Expected packed or unpacked quant_state items, found neither")
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")
raise ValueError(
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
)
# unpacking minor and non-tensor quant state items if necessary
if len(qs_key) == 1:
first_qs_key = qs_key[0]
qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))
qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes
qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes
assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
if 'nested_absmax' in qs_dict:
offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
if "nested_absmax" in qs_dict:
offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
state2 = cls(
absmax=qs_dict['nested_absmax'].to(device),
blocksize=qs_dict['nested_blocksize'],
code=qs_dict['nested_quant_map'].to(device),
dtype=getattr(torch, qs_dict['nested_dtype']),
absmax=qs_dict["nested_absmax"].to(device),
blocksize=qs_dict["nested_blocksize"],
code=qs_dict["nested_quant_map"].to(device),
dtype=getattr(torch, qs_dict["nested_dtype"]),
)
else:
offset, state2 = None, None
quant_state = cls(
quant_type=qs_dict['quant_type'],
absmax=qs_dict['absmax'].to(device),
blocksize=qs_dict['blocksize'],
code=qs_dict['quant_map'].to(device),
dtype=getattr(torch, qs_dict['dtype']),
shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None,
quant_type=qs_dict["quant_type"],
absmax=qs_dict["absmax"].to(device),
blocksize=qs_dict["blocksize"],
code=qs_dict["quant_map"].to(device),
dtype=getattr(torch, qs_dict["dtype"]),
shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None,
offset=offset,
state2=state2,
)
......@@ -674,21 +736,23 @@ class QuantState:
param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
"""
qs_dict = {
'quant_type': self.quant_type,
'absmax': self.absmax,
'blocksize': self.blocksize,
'quant_map': self.code,
'dtype': str(self.dtype).strip('torch.'),
'shape': tuple(self.shape),
"quant_type": self.quant_type,
"absmax": self.absmax,
"blocksize": self.blocksize,
"quant_map": self.code,
"dtype": str(self.dtype).strip("torch."),
"shape": tuple(self.shape),
}
if self.nested:
qs_dict.update({
'nested_absmax': self.state2.absmax,
'nested_blocksize': self.state2.blocksize,
'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors
'nested_dtype': str(self.state2.dtype).strip('torch.'),
'nested_offset': self.offset.item(),
})
qs_dict.update(
{
"nested_absmax": self.state2.absmax,
"nested_blocksize": self.state2.blocksize,
"nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors
"nested_dtype": str(self.state2.dtype).strip("torch."),
"nested_offset": self.offset.item(),
},
)
if not packed:
return qs_dict
......@@ -711,14 +775,22 @@ class QuantState:
return False
return (
torch.allclose(self.absmax, other.absmax, atol=1e-6) and
self.shape == other.shape and
torch.allclose(self.code, other.code, atol=1e-6) and
self.dtype == other.dtype and
self.blocksize == other.blocksize and
self.quant_type == other.quant_type and
(self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
(self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
torch.allclose(self.absmax, other.absmax, atol=1e-6)
and self.shape == other.shape
and torch.allclose(self.code, other.code, atol=1e-6)
and self.dtype == other.dtype
and self.blocksize == other.blocksize
and self.quant_type == other.quant_type
and (
self.offset == other.offset
if self.offset is not None and other.offset is not None
else self.offset is other.offset
)
and (
self.state2 == other.state2
if self.state2 is not None and other.state2 is not None
else self.state2 is other.state2
)
)
......@@ -756,7 +828,6 @@ def quantize_blockwise(
The quantization state to undo the quantization.
"""
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
......@@ -771,31 +842,66 @@ def quantize_blockwise(
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
if A.device.type != "cpu":
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
lib.cquantize_blockwise_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
cblocksize,
ct.c_int(A.numel()),
)
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
lib.cquantize_blockwise_fp16(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
cblocksize,
ct.c_int(A.numel()),
)
elif A.dtype == torch.bfloat16:
lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
lib.cquantize_blockwise_bf16(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
cblocksize,
ct.c_int(A.numel()),
)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
# cpu
code = code.cpu()
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
if nested:
offset = absmax.mean()
absmax -= offset
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
quant_state = QuantState(
absmax=qabsmax,
code=code,
blocksize=blocksize,
dtype=A.dtype,
offset=offset,
state2=state2,
)
else:
quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
......@@ -809,7 +915,7 @@ def dequantize_blockwise(
code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 4096,
nested=False
nested=False,
) -> Tensor:
"""
Dequantizes blockwise quantized values.
......@@ -849,37 +955,70 @@ def dequantize_blockwise(
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32: absmax = absmax.float()
if absmax.dtype != torch.float32:
absmax = absmax.float()
if out is None:
out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
if A.device.type != 'cpu':
if A.device.type != "cpu":
device = pre_call(A.device)
code = quant_state.code.to(A.device)
if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
raise ValueError(
f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
)
is_on_gpu([A, absmax, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp32(
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp16(
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
)
elif out.dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_bf16(
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
code = quant_state.code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))
lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(quant_state.absmax),
get_ptr(out),
ct.c_longlong(quant_state.blocksize),
ct.c_longlong(A.numel()),
)
return out
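A round-trip sketch tying the two functions above together (assumes a CUDA device; illustration only, not part of this commit):

import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

A = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
q, state = quantize_blockwise(A, blocksize=4096)   # uint8 codes + QuantState holding code/absmax/blocksize
A_hat = dequantize_blockwise(q, state)             # back to fp16, up to blockwise quantization error
print((A - A_hat).abs().max())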
def get_4bit_type(typename, device=None, blocksize=64):
if device is None: device = 'cuda'
if device is None:
device = "cuda"
data = None
if typename == 'nf4':
''' Implements the NF4 data type.
if typename == "nf4":
""" Implements the NF4 data type.
Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
is normalized into the range [-1, 1].
......@@ -888,12 +1027,26 @@ def get_4bit_type(typename, device=None, blocksize=64):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
-0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
0.7229568362236023, 1.0]
elif typename == 'fp4':
"""
data = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
elif typename == "fp4":
# 0b000 = 0
# 0b001 = 0.0625
# 0b010 = 8
......@@ -904,20 +1057,35 @@ def get_4bit_type(typename, device=None, blocksize=64):
# 0b111 = 3
# can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
elif typename == 'int4':
elif typename == "int4":
data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
elif typename == 'af4':
elif typename == "af4":
# Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
# https://arxiv.org/abs/2306.06965
if blocksize == 64:
data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
-0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
0.42563882, 0.55496234, 0.72424863, 1.][::-1]
data = [
-1.0,
-0.69441008,
-0.51243739,
-0.3736951,
-0.25607552,
-0.14982478,
-0.04934812,
0.0,
0.04273164,
0.12934483,
0.21961274,
0.31675666,
0.42563882,
0.55496234,
0.72424863,
1.0,
][::-1]
else:
raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.')
raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.")
if data is None:
raise NotImplementedError(f'Typename {typename} not supported')
raise NotImplementedError(f"Typename {typename} not supported")
data = Tensor(data)
data /= data.abs().max()
......@@ -926,11 +1094,26 @@ def get_4bit_type(typename, device=None, blocksize=64):
return data.to(device)
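A quick sketch of the returned code table (follows from the code above; illustration only):

from bitsandbytes.functional import get_4bit_type

nf4 = get_4bit_type("nf4", device="cpu")
assert nf4.numel() == 16                            # 16 bins for a 4-bit data type
assert abs(nf4.abs().max().item() - 1.0) < 1e-6     # normalized by data.abs().max() above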
def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
def quantize_fp4(
A: Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_storage=torch.uint8,
):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
def quantize_nf4(
A: Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_storage=torch.uint8,
):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
def quantize_4bit(
......@@ -939,7 +1122,7 @@ def quantize_4bit(
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_type='fp4',
quant_type="fp4",
quant_storage=torch.uint8,
) -> Tuple[Tensor, QuantState]:
"""
......@@ -967,10 +1150,10 @@ def quantize_4bit(
tuple(torch.Tensor, QuantState):
The quantized tensor and the quantization state needed to undo the quantization.
"""
if A.device.type != 'cuda':
raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
if quant_type not in ['fp4', 'nf4']:
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
if A.device.type != "cuda":
raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
if quant_type not in ["fp4", "nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
n = A.numel()
input_shape = A.shape
......@@ -980,10 +1163,9 @@ def quantize_4bit(
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
if out is None:
mod = dtype2bytes[quant_storage] * 2
out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)
out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
......@@ -991,20 +1173,62 @@ def quantize_4bit(
is_on_gpu([A, out, absmax])
if A.dtype == torch.float32:
if quant_type == 'fp4':
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
if quant_type == "fp4":
lib.cquantize_blockwise_fp32_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else:
lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
lib.cquantize_blockwise_fp32_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
elif A.dtype == torch.float16:
if quant_type == 'fp4':
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
if quant_type == "fp4":
lib.cquantize_blockwise_fp16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else:
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
lib.cquantize_blockwise_fp16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
elif A.dtype == torch.bfloat16:
if quant_type == 'fp4':
lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else:
lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
lib.cquantize_blockwise_bf16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
......@@ -1016,19 +1240,57 @@ def quantize_4bit(
absmax -= offset
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
del absmax
state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
state = QuantState(
absmax=qabsmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
offset=offset,
state2=state2,
)
else:
state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, )
state = QuantState(
absmax=absmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
)
return out, state
def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
def dequantize_fp4(
A: Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
def dequantize_nf4(
A: Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
def dequantize_4bit(
A: Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type="fp4",
) -> Tensor:
"""
Dequantizes 4-bit blockwise quantized values (FP4 or NF4).
......@@ -1056,23 +1318,31 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax:
Dequantized tensor.
"""
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
if quant_type not in ['fp4', 'nf4']:
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
)
if quant_type not in ["fp4", "nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
if quant_state is None:
assert absmax is not None and out is not None
quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
quant_state = QuantState(
absmax=absmax,
shape=out.shape,
dtype=out.dtype,
blocksize=blocksize,
quant_type=quant_type,
)
else:
absmax = quant_state.absmax
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32: absmax = absmax.float()
if absmax.dtype != torch.float32:
absmax = absmax.float()
if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
......@@ -1082,27 +1352,71 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax:
device = pre_call(A.device)
is_on_gpu([A, absmax, out])
if out.dtype == torch.float32:
if quant_state.quant_type == 'fp4':
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else:
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
lib.cdequantize_blockwise_fp32_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
elif out.dtype == torch.float16:
if quant_state.quant_type == 'fp4':
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else:
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
lib.cdequantize_blockwise_fp16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
elif out.dtype == torch.bfloat16:
if quant_state.quant_type == 'fp4':
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else:
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
lib.cdequantize_blockwise_bf16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
is_transposed = (True if A.shape[0] == 1 else False)
if is_transposed: return out.t()
else: return out
is_transposed = True if A.shape[0] == 1 else False
if is_transposed:
return out.t()
else:
return out
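# A minimal round-trip sketch for the function above. The quantize_4bit signature and its
# (packed tensor, QuantState) return value are assumed from how the state is consumed here
# and in gemv_4bit below; NF4 is lossy, so expect a small reconstruction error.
import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
W4, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")  # assumed signature
W_hat = F.dequantize_4bit(W4, quant_state=state)
print((W - W_hat).abs().mean())  # small but nonzero quantization error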
def quantize(
......@@ -1117,7 +1431,8 @@ def quantize(
code = code.to(A.device)
absmax = torch.abs(A).max()
if absmax.dtype != torch.float32: absmax = absmax.float()
if absmax.dtype != torch.float32:
absmax = absmax.float()
inp = A / absmax
out = quantize_no_absmax(inp, code, out)
return out, (absmax, code)
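# Sketch of the inverse of quantize(): dequantize_no_absmax() maps the uint8 codes back to
# floats through the same code table, and multiplying by the returned absmax undoes the
# normalization above. create_dynamic_map() is assumed to supply the 256-entry code tensor.
import torch
import bitsandbytes.functional as F

code = F.create_dynamic_map().cuda()  # assumed helper for building the quantization map
A = torch.randn(1024, 1024, device="cuda")
A8, (absmax, code) = F.quantize(A, code=code)
A_hat = F.dequantize_no_absmax(A8, code) * absmax
print((A - A_hat).abs().max())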
......@@ -1144,7 +1459,7 @@ def dequantize(
def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
'''
"""
Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
......@@ -1163,9 +1478,10 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
-------
torch.Tensor:
Quantized 8-bit tensor.
'''
"""
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
......@@ -1173,7 +1489,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
'''
"""
Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
......@@ -1192,9 +1508,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
-------
torch.Tensor:
32-bit output tensor.
'''
"""
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
......@@ -1261,16 +1578,17 @@ def optimizer_update_32bit(
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
optim_func = None
if g.dtype == torch.float32:
optim_func = str2optimizer32bit[optimizer_name][0]
elif g.dtype == torch.float16:
optim_func = str2optimizer32bit[optimizer_name][1]
elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
optim_func = str2optimizer32bit[optimizer_name][2]
else:
raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
)
is_on_gpu([g, p, state1, state2, unorm_vec])
prev_device = pre_call(g.device)
......@@ -1290,7 +1608,8 @@ def optimizer_update_32bit(
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()))
ct.c_int32(g.numel()),
)
post_call(prev_device)
......@@ -1422,7 +1741,7 @@ def optimizer_update_8bit(
)
else:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
)
post_call(prev_device)
......@@ -1446,7 +1765,6 @@ def optimizer_update_8bit_blockwise(
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
optim_func = None
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
......@@ -1454,12 +1772,15 @@ def optimizer_update_8bit_blockwise(
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
len(str2optimizer8bit_blockwise[optimizer_name])==3):
elif (
g.dtype == torch.bfloat16
and state1.dtype == torch.uint8
and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
):
optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
else:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
)
post_call(prev_device)
......@@ -1487,9 +1808,8 @@ def optimizer_update_8bit_blockwise(
)
post_call(prev_device)
def percentile_clipping(
grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
"""Applies percentile clipping
grad: torch.Tensor
......@@ -1531,9 +1851,7 @@ def percentile_clipping(
return current_gnorm, clip_value, gnorm_scale
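# Usage sketch for the clipping helper above: keep a rolling buffer of recent gradient norms
# and rescale a gradient whenever its norm exceeds the chosen percentile. The 100-step buffer
# size is an assumption; the function body is elided in this hunk.
import torch
import bitsandbytes.functional as F

gnorm_vec = torch.zeros(100, device="cuda")  # history of recent gradient norms
grads = [torch.randn(512, 512, device="cuda") for _ in range(10)]
for step, grad in enumerate(grads, start=1):
    current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, gnorm_vec, step, percentile=5)
    if gnorm_scale < 1.0:
        grad.mul_(gnorm_scale)  # shrink outlier gradients in place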
def histogram_scatter_add_2d(
histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
......@@ -1550,12 +1868,12 @@ def histogram_scatter_add_2d(
is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized(): torch.cuda.init()
if not torch.cuda.is_initialized():
torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
raise TypeError(
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
)
raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}")
sA = A.shape
sB = B.shape
......@@ -1596,12 +1914,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
sout = out.shape
# special case common in backprop
if not correct and len(sA) == 3 and len(sB) == 3:
if (
sout[0] == sA[2]
and sout[1] == sB[2]
and sA[0] == sB[0]
and sA[1] == sB[1]
):
if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]:
correct = True
else:
if len(sA) == 2 and len(sB) == 2:
......@@ -1634,26 +1947,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
if not correct:
raise ValueError(
f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.",
)
return sout
def gemv_4bit(
A: Tensor,
B: Tensor,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
state=None
state=None,
):
prev_device = pre_call(A.device)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
# sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if state is None:
raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )")
if A.numel() != A.shape[-1]:
raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
raise ValueError(
'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]',
)
Bshape = state.shape
bout = Bshape[0]
......@@ -1673,7 +1989,7 @@ def gemv_4bit(
k = Bshape[1]
lda = Bshape[0]
ldc = Bshape[0]
ldb = (A.shape[-1]+1)//2
ldb = (A.shape[-1] + 1) // 2
is_on_gpu([B, A, out, absmax, state.code])
m = ct.c_int32(m)
n = ct.c_int32(n)
......@@ -1684,21 +2000,61 @@ def gemv_4bit(
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
lib.cgemm_4bit_inference_naive_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
)
elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
lib.cgemm_4bit_inference_naive_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
)
elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
lib.cgemm_4bit_inference_naive_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
)
else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
post_call(prev_device)
return out
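# Inference sketch for gemv_4bit: a single-token matvec against a 4-bit weight, using the
# state produced by quantize_4bit as the error message above requires. Transposing the packed
# weight follows the convention of the 4-bit matmul wrappers and is an assumption here.
import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 2048, dtype=torch.float16, device="cuda")  # (out_features, in_features)
W4, state = F.quantize_4bit(W, quant_type="nf4")
x = torch.randn(1, 1, 2048, dtype=torch.float16, device="cuda")  # leading dimensions must be 1
y = F.gemv_4bit(x, W4.t(), state=state)  # output features = state.shape[0]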
def igemm(
A: Tensor,
B: Tensor,
......@@ -1764,7 +2120,7 @@ def igemm(
assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]):
raise ValueError(
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}",
)
transposed_A = True
......@@ -1783,8 +2139,20 @@ def igemm(
# B^T @ A^T = C^T
# [km, nk -> mn]
is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
lib.cigemm(
ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
)
return out
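# Minimal igemm sketch: both operands must be torch.int8 on CUDA (enforced by check_matmul
# above) and the product accumulates into int32. Shapes follow the usual (m, k) @ (k, n) rule.
import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 128, (64, 256), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 128, (256, 128), dtype=torch.int8, device="cuda")
C = F.igemm(A, B)  # int32 tensor of shape (64, 128)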
......@@ -1796,9 +2164,7 @@ def batched_igemm(
transposed_B=False,
):
if not len(A.shape) == 3 or not len(B.shape) == 3:
raise ValueError(
f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
)
raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}")
sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None:
out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
......@@ -1865,9 +2231,24 @@ def batched_igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)
is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
lib.cbatched_igemm(
ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
ct.c_long(strideA),
ct.c_long(strideB),
ct.c_long(strideC),
ct.c_uint32(num_batch),
)
return out
......@@ -1876,14 +2257,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0]
dimsA = len(shapeA)
dimsB = len(shapeB)
assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
assert dimsB == 2, "Only two dimensional matrices are supported for argument B"
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0] * shapeA[1]
rows = n = shapeB[0]
assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}"
# if the tensor is empty, return a transformed empty tensor with the right dimensions
if shapeA[0] == 0 and dimsA == 2:
......@@ -1892,13 +2273,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer(
(shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
)
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer(
(shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
)
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
assert dimsB != 3, "len(B.shape)==3 not supported"
assert A.device.type == "cuda"
......@@ -1940,49 +2317,33 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
is_on_gpu([A, B, out])
if formatB == 'col_turing':
if formatB == "col_turing":
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_turing_8(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == "col_ampere":
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_ampere_8(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
if has_error:
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
raise Exception('cublasLt ran into an error!')
print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}")
raise Exception("cublasLt ran into an error!")
torch.cuda.set_device(prev_device)
return out, Sout
def mm_dequant(
A,
quant_state,
row_stats,
col_stats,
out=None,
new_row_stats=None,
new_col_stats=None,
bias=None
):
def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
assert A.dtype == torch.int32
if bias is not None: assert bias.dtype == torch.float16
if bias is not None:
assert bias.dtype == torch.float16
out_shape = quant_state[0]
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
......@@ -1990,19 +2351,11 @@ def mm_dequant(
if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None:
new_row_stats = torch.empty(
out_shape[0], dtype=torch.float32, device=A.device
)
new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
if new_col_stats is None:
new_col_stats = torch.empty(
out_shape[1], dtype=torch.float32, device=A.device
)
assert (
new_row_stats.shape[0] == row_stats.shape[0]
), f"{new_row_stats.shape} vs {row_stats.shape}"
assert (
new_col_stats.shape[0] == col_stats.shape[0]
), f"{new_col_stats.shape} vs {col_stats.shape}"
new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
prev_device = pre_call(A.device)
ptrA = get_ptr(A)
......@@ -2016,15 +2369,23 @@ def mm_dequant(
numCols = ct.c_int32(out_shape[1])
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols)
lib.cdequant_mm_int32_fp16(
ptrA,
ptrRowStats,
ptrColStats,
ptrOut,
ptrNewRowStats,
ptrNewColStats,
ptrBias,
numRows,
numCols,
)
post_call(prev_device)
return out
def get_colrow_absmax(
A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
assert A.dtype == torch.float16
device = A.device
......@@ -2037,18 +2398,12 @@ def get_colrow_absmax(
col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None:
row_stats = torch.empty(
(rows,), dtype=torch.float32, device=device
).fill_(-50000.0)
row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
if col_stats is None:
col_stats = torch.empty(
(cols,), dtype=torch.float32, device=device
).fill_(-50000.0)
col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros(
((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
)
nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)
ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
......@@ -2122,14 +2477,10 @@ class CSCSparseTensor:
def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1)
rowptr = torch.zeros(
(cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
)
rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0)
return CSRSparseTensor(
cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
)
return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
def coo2csc(cooA):
......@@ -2138,14 +2489,10 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
colptr = torch.zeros(
(cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
)
colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
return CSCSparseTensor(
cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
)
return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
......@@ -2155,9 +2502,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
def double_quant(
A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
device = A.device
assert A.dtype == torch.half
assert device.type == "cuda"
......@@ -2170,9 +2515,7 @@ def double_quant(
rows = A.shape[0]
if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
A, threshold=threshold
)
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
......@@ -2190,9 +2533,7 @@ def double_quant(
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
coo_tensor = coo_zeros(
A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
)
coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values)
......@@ -2251,12 +2592,16 @@ def double_quant(
return out_row, out_col, row_stats, col_stats, coo_tensor
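# End-to-end sketch of the int8 matmul path these helpers form: double_quant produces the
# int8 tensors and per-row/column statistics, transform converts them to the tile layouts the
# kernels expect, igemmlt runs the int8 GEMM, and mm_dequant fuses the scaling back to fp16.
# transform() returning (tensor, state) and the get_special_format_str() helper are assumptions
# based on how these functions are used elsewhere in the library.
import torch
import bitsandbytes.functional as F

A = torch.randn(32, 4096, dtype=torch.float16, device="cuda")     # activations, (m, k)
W = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # weight stored as (n, k)

CA, CAt, SCA, SCAt, _ = F.double_quant(A)
CW, CWt, SCW, SCWt, _ = F.double_quant(W)

C32A, SA = F.transform(CA, "col32")
formatB = F.get_special_format_str()          # "col_turing" or "col_ampere"
CxW, SW = F.transform(CW, to_order=formatB)

out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)  # int32 accumulators, shape (m, n)
out = F.mm_dequant(out32, Sout32, SCA, SCW)   # fp16 result, equivalent to A @ W.t()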
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order)
else: from_order = state[1]
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else: new_state = (state[0], to_order) # (shape, order)
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else:
new_state = (state[0], to_order) # (shape, order)
shape = state[0]
if len(shape) == 2:
......@@ -2267,7 +2612,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
dim2 = ct.c_int32(shape[2])
is_on_gpu([A, out])
if to_order == 'col32':
if to_order == "col32":
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
......@@ -2288,7 +2633,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
post_call(prev_device)
......@@ -2297,9 +2642,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
def spmm_coo(cooA, B, out=None):
if out is None:
out = torch.empty(
(cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
)
out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
......@@ -2326,16 +2669,28 @@ def spmm_coo(cooA, B, out=None):
cldc = ct.c_int32(ldc)
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
lib.cspmm_coo(
ptr,
ptrRowidx,
ptrColidx,
ptrValues,
cnnz,
crowsA,
ccolsA,
ccolsB,
cldb,
ptrB,
cldc,
ptrC,
ct.c_bool(transposed_B),
)
return out
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
if out is None:
out = torch.zeros(
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
nnz = cooA.nnz
prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz
......@@ -2353,9 +2708,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
max_count, max_idx = torch.sort(counts, descending=True)
max_idx = max_idx.int()
max_count = max_count.int()
assert (
max_count[0] <= 32
), f"Current max count per row is 8 but found {max_count[0]}."
assert max_count[0] <= 32, f"Maximum allowed count per row is 32, but found {max_count[0]}."
assert B.dtype in [torch.float16, torch.int8]
ptrOffset = get_ptr(offset)
ptrMaxCount = get_ptr(max_count)
......@@ -2443,9 +2796,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
x, dim=dim, keepdim=True
)
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
dyna[dyna == 0] = 1
qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
......@@ -2553,9 +2904,7 @@ def extract_outliers(A, SA, idx):
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"
out = torch.zeros(
(shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
)
out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
......@@ -2565,7 +2914,7 @@ def extract_outliers(A, SA, idx):
ptrOut = get_ptr(out)
prev_device = pre_call(A.device)
if formatA == 'col_turing':
if formatA == "col_turing":
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
......@@ -2573,6 +2922,7 @@ def extract_outliers(A, SA, idx):
return out
def pipeline_test(A, batch_size):
out = torch.zeros_like(A)
lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
......
......@@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding):
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
"""
def __init__(
self,
num_embeddings: int,
......@@ -89,9 +90,7 @@ class StableEmbedding(torch.nn.Embedding):
dtype,
)
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
)
GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
......@@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding):
"""
Embedding class to store and retrieve word embeddings from their indices.
"""
def __init__(
self,
num_embeddings: int,
......@@ -170,11 +170,9 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
device=device
)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
device=device,
)
GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
......@@ -214,10 +212,10 @@ class Params4bit(torch.nn.Parameter):
quant_state: Optional[QuantState] = None,
blocksize: int = 64,
compress_statistics: bool = True,
quant_type: str = 'fp4',
quant_type: str = "fp4",
quant_storage: torch.dtype = torch.uint8,
module: Optional["Linear4bit"] = None,
bnb_quantized: bool = False
bnb_quantized: bool = False,
) -> "Params4bit":
if data is None:
data = torch.empty(0)
......@@ -250,7 +248,7 @@ class Params4bit(torch.nn.Parameter):
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]
def __deepcopy__(self,memo):
def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
......@@ -265,7 +263,14 @@ class Params4bit(torch.nn.Parameter):
return new_instance
@classmethod
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
def from_prequantized(
cls,
data: torch.Tensor,
quantized_stats: Dict[str, Any],
requires_grad: bool = False,
device="cuda",
**kwargs,
) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
......@@ -292,33 +297,39 @@ class Params4bit(torch.nn.Parameter):
return self
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
...
def to(
self: T,
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...,
) -> T: ...
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
...
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
...
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (device is not None and device.type == "cuda" and not self.bnb_quantized):
if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
self.quant_state.to(device)
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad, quant_state=self.quant_state,
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type)
new_param = Params4bit(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
quant_state=self.quant_state,
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
)
return new_param
......@@ -355,7 +366,18 @@ class Linear4bit(nn.Linear):
quantized_model = quantized_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_type="fp4",
quant_storage=torch.uint8,
device=None,
):
"""
Initialize Linear4bit class.
......@@ -368,7 +390,14 @@ class Linear4bit(nn.Linear):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
self.weight = Params4bit(
self.weight.data,
requires_grad=False,
compress_statistics=compress_statistics,
quant_type=quant_type,
quant_storage=quant_storage,
module=self,
)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
self.compute_type_is_set = False
......@@ -385,11 +414,15 @@ class Linear4bit(nn.Linear):
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
warnings.filterwarnings('ignore', message='.*inference.')
warnings.warn(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
)
warnings.filterwarnings("ignore", message=".*inference.")
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.filterwarnings('ignore', message='.*inference or training')
warnings.warn(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
)
warnings.filterwarnings("ignore", message=".*inference or training")
def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
......@@ -407,8 +440,8 @@ class Linear4bit(nn.Linear):
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, 'quant_state', None) is None:
if getattr(self, 'quant_state', None) is not None:
if getattr(self.weight, "quant_state", None) is None:
if getattr(self, "quant_state", None) is not None:
# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1
......@@ -416,7 +449,9 @@ class Linear4bit(nn.Linear):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
self.weight.quant_state = self.quant_state
else:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)
if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
......@@ -437,7 +472,17 @@ class LinearFP4(Linear4bit):
"""
Implements the FP4 data type.
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_storage=torch.uint8,
device=None,
):
"""
Args:
input_features (`str`):
......@@ -447,11 +492,20 @@ class LinearFP4(Linear4bit):
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
super().__init__(
input_features,
output_features,
bias,
compute_dtype,
compress_statistics,
"fp4",
quant_storage,
device,
)
class LinearNF4(Linear4bit):
''' Implements the NF4 data type.
"""Implements the NF4 data type.
Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
is normalized into the range [-1, 1].
......@@ -460,8 +514,18 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_storage=torch.uint8,
device=None,
):
"""
Args:
input_features (`str`):
......@@ -471,7 +535,16 @@ class LinearNF4(Linear4bit):
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
super().__init__(
input_features,
output_features,
bias,
compute_dtype,
compress_statistics,
"nf4",
quant_storage,
device,
)
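# Drop-in sketch mirroring the Linear4bit docstring: build the NF4 layer, load fp16 weights
# into it, and let the move to CUDA perform the 4-bit quantization.
import torch
import bitsandbytes as bnb

fp16_layer = torch.nn.Linear(4096, 4096)
nf4_layer = bnb.nn.LinearNF4(4096, 4096, compute_dtype=torch.float16)
nf4_layer.load_state_dict(fp16_layer.state_dict())
nf4_layer = nf4_layer.cuda()  # quantization happens here

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = nf4_layer(x)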
class Int8Params(torch.nn.Parameter):
......@@ -514,33 +587,22 @@ class Int8Params(torch.nn.Parameter):
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...,
) -> T:
...
) -> T: ...
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
...
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
...
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
)
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (
device is not None
and device.type == "cuda"
and self.data.device.type == "cpu"
):
if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
return self.cuda(device)
else:
new_param = Int8Params(
super().to(
device=device, dtype=dtype, non_blocking=non_blocking
),
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights,
)
......@@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear):
int8_model = int8_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None,
):
"""
Initialize Linear8bitLt class.
......@@ -647,19 +719,36 @@ class Linear8bitLt(nn.Linear):
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
unexpected_copy = list(unexpected_keys)
for key in unexpected_copy:
input_name = key[len(prefix):]
input_name = key[len(prefix) :]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
raise RuntimeError(
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()",
)
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
......@@ -702,18 +791,18 @@ class OutlierAwareLinear(nn.Linear):
self.is_quantized = False
def forward_with_outliers(self, x, outlier_idx):
raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function")
def quantize_weight(self, w, outlier_idx):
raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')
raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function")
def forward(self, x):
if self.outlier_dim is None:
tracer = OutlierTracer.get_instance()
if not tracer.is_initialized():
print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
outlier_idx = tracer.get_outliers(self.weight)
#print(outlier_idx, tracer.get_hvalue(self.weight))
# print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx
if not self.is_quantized:
......@@ -721,6 +810,7 @@ class OutlierAwareLinear(nn.Linear):
self.weight.data.copy_(w)
self.is_quantized = True
class SwitchBackLinearBnb(nn.Linear):
def __init__(
self,
......@@ -731,11 +821,9 @@ class SwitchBackLinearBnb(nn.Linear):
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None
device=None,
):
super().__init__(
input_features, output_features, bias, device
)
super().__init__(input_features, output_features, bias, device)
self.state = bnb.MatmulLtState()
self.index = index
......@@ -745,9 +833,7 @@ class SwitchBackLinearBnb(nn.Linear):
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
def init_8bit_state(self):
self.state.CB = self.weight.CB
......
......@@ -22,7 +22,6 @@ from bitsandbytes.triton.triton_utils import is_triton_available
class _switchback_global(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
......@@ -37,9 +36,7 @@ class _switchback_global(torch.autograd.Function):
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
......@@ -56,7 +53,8 @@ class _switchback_global(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
*G_3D.size()[:-1],
-1,
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
......@@ -66,8 +64,8 @@ class _switchback_global(torch.autograd.Function):
return grad_X, grad_W, grad_bias
class _switchback_vectorrize(torch.autograd.Function):
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
......@@ -81,9 +79,7 @@ class _switchback_vectorrize(torch.autograd.Function):
# matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
......@@ -99,7 +95,8 @@ class _switchback_vectorrize(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
*G_3D.size()[:-1],
-1,
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
......@@ -109,8 +106,8 @@ class _switchback_vectorrize(torch.autograd.Function):
return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
......@@ -127,9 +124,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
......@@ -151,12 +146,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1)
return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear):
def __init__(
self,
......@@ -166,20 +160,20 @@ class SwitchBackLinear(nn.Linear):
device=None,
dtype=None,
vector_wise_quantization: bool = False,
mem_efficient : bool = False,
mem_efficient: bool = False,
):
super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available():
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""")
# By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization:
self._fn = _switchback_vectorrize
if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.')
print("mem efficient is not supported for vector-wise quantization.")
exit(1)
else:
if mem_efficient:
......@@ -195,7 +189,7 @@ class SwitchBackLinear(nn.Linear):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print('=> preparing for eval.')
print("=> preparing for eval.")
if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight)
else:
......@@ -219,18 +213,22 @@ class SwitchBackLinear(nn.Linear):
X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
*x.size()[:-1],
-1,
)
else:
return int8_matmul_mixed_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
*x.size()[:-1],
-1,
)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
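# Quick usage sketch for the partials above; SwitchBackLinear requires triton, as the
# ImportError in __init__ states, and the export of these names from bitsandbytes.nn is assumed.
import torch
import bitsandbytes as bnb

layer = bnb.nn.SwitchBackLinearGlobal(1024, 4096).cuda().half()
x = torch.randn(8, 128, 1024, dtype=torch.float16, device="cuda")
y = layer(x)  # rowwise-quantized input x globally-quantized weight, fused dequantize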
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
@staticmethod
......@@ -260,7 +258,7 @@ class StandardLinearFunction(torch.autograd.Function):
return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
class StandardLinear(nn.Linear):
def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias)
......@@ -50,9 +50,7 @@ class Adagrad(Optimizer1State):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
......@@ -119,9 +117,7 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
......@@ -189,9 +185,7 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
......
......@@ -14,8 +14,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Base Adam optimizer.
......@@ -45,11 +58,38 @@ class Adam(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit Adam optimizer.
......@@ -79,11 +119,38 @@ class Adam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
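# Typical swap-in sketch: Adam8bit stores its optimizer state in blockwise-quantized 8-bit
# buffers for parameters larger than min_8bit_size, roughly quartering optimizer memory
# relative to 32-bit state.
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()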
class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
32-bit Adam optimizer.
......@@ -113,11 +180,38 @@ class Adam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Paged Adam optimizer.
......@@ -147,11 +241,38 @@ class PagedAdam(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit paged Adam optimizer.
......@@ -181,11 +302,38 @@ class PagedAdam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Paged 32-bit Adam optimizer.
......@@ -215,7 +363,21 @@ class PagedAdam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis.
......@@ -293,9 +455,7 @@ class AnalysisAdam(torch.optim.Optimizer):
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
amsgrad = group.get("amsgrad", False)
assert not amsgrad
......@@ -312,15 +472,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
state["abserrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["counts"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
......@@ -328,25 +482,19 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
if amsgrad:
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
p_data_fp32
)
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
state["step"] += 1
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = (
group["lr"] * math.sqrt(bias_correction2) / bias_correction1
)
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
e = state["abserrors"]
rele = state["relerrors"]
counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
......@@ -359,10 +507,7 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom
if (
p_data_fp32.numel() <= 8192
or p_data_fp32.numel() > 50000 * 1000
):
if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
# embedding layer or too small
p_data_fp32 += -step_size * update_fp32
else:
......@@ -401,9 +546,7 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize
# Error will be calculated automatically!
else:
raise ValueError(
f"Invalid analysis value: {self.analysis}!"
)
raise ValueError(f"Invalid analysis value: {self.analysis}!")
denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom
......@@ -415,9 +558,7 @@ class AnalysisAdam(torch.optim.Optimizer):
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
F.histogram_scatter_add_2d(
counts, C1.int(), C2.int(), torch.ones_like(abserr)
)
F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
p_data_fp32 += -step_size * update_fp32
......@@ -425,18 +566,10 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir):
os.makedirs(self.savedir)
shapestr = "_".join(
[str(dim) for dim in p_data_fp32.shape]
)
pathe = os.path.join(
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
pathrele = os.path.join(
self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
)
pathcounts = os.path.join(
self.savedir, f"{p_id}_{shapestr}_counts.pkl"
)
shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl")
pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl")
pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl")
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
......
......@@ -6,8 +6,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Base AdamW optimizer.
......@@ -37,11 +50,38 @@ class AdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
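A minimal usage sketch for this class, assuming bitsandbytes is installed and a CUDA device is available (the toy model below is a placeholder, not part of this diff):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()
optimizer = bnb.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()       # updates parameters with 32-bit AdamW state by default
optimizer.zero_grad()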
class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit AdamW optimizer.
......@@ -71,11 +111,38 @@ class AdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
32-bit AdamW optimizer.
......@@ -105,12 +172,37 @@ class AdamW32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged AdamW optimizer.
......@@ -140,11 +232,37 @@ class PagedAdamW(Optimizer2State):
Note: this optimizer is always paged; the constructor takes no `is_paged` argument and `is_paged=True` is passed to the base class.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 8-bit AdamW optimizer.
......@@ -174,11 +292,37 @@ class PagedAdamW8bit(Optimizer2State):
Note: this optimizer is always paged; the constructor takes no `is_paged` argument and `is_paged=True` is passed to the base class.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 32-bit AdamW optimizer.
......@@ -208,4 +352,17 @@ class PagedAdamW32bit(Optimizer2State):
Note: this optimizer is always paged; the constructor takes no `is_paged` argument and `is_paged=True` is passed to the base class.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
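The paged variants above differ from the non-paged ones only in that they hard-code `is_paged=True`, which lets optimizer state be evicted to CPU memory under GPU memory pressure. A hedged usage sketch (placeholder model, CUDA assumed):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3)

model(torch.randn(2, 4096, device="cuda")).sum().backward()
optimizer.step()       # paged state tensors are prefetched/evicted as needed
optimizer.zero_grad()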
......@@ -51,9 +51,7 @@ class LARS(Optimizer1State):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
)
raise NotImplementedError("LARS without momentum is not supported!")
super().__init__(
"lars",
params,
......@@ -110,9 +108,7 @@ class LARS8bit(Optimizer1State):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
)
raise NotImplementedError("LARS without momentum is not supported!")
super().__init__(
"lars",
params,
......@@ -169,9 +165,7 @@ class LARS32bit(Optimizer1State):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
)
raise NotImplementedError("LARS without momentum is not supported!")
super().__init__(
"lars",
params,
......@@ -204,9 +198,7 @@ class PytorchLARS(Optimizer):
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
......@@ -217,9 +209,7 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening"
)
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
def __setstate__(self, state):
......
......@@ -6,7 +6,19 @@ from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Base Lion optimizer.
......@@ -32,10 +44,35 @@ class Lion(Optimizer1State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
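For reference, a simplified fp32 sketch of the Lion update rule these classes wrap (this follows the published algorithm, not the fused 8-bit kernel; lion_update_ref is a made-up helper):

import torch

def lion_update_ref(p, grad, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    # Step direction is the sign of an interpolation between momentum and gradient.
    update = (beta1 * m + (1 - beta1) * grad).sign()
    p.add_(update + weight_decay * p, alpha=-lr)
    # Momentum is an EMA of the gradient with the second beta.
    m.mul_(beta2).add_(grad, alpha=1 - beta2)

p, g, m = torch.randn(10), torch.randn(10), torch.zeros(10)
lion_update_ref(p, g, m)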
class Lion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit Lion optimizer.
......@@ -59,10 +96,35 @@ class Lion8bit(Optimizer1State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
32-bit Lion optimizer.
......@@ -86,11 +148,35 @@ class Lion32bit(Optimizer1State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedLion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged Lion optimizer.
......@@ -114,10 +200,34 @@ class PagedLion(Optimizer1State):
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 8-bit Lion optimizer.
......@@ -141,10 +251,34 @@ class PagedLion8bit(Optimizer1State):
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 32-bit Lion optimizer.
......@@ -168,4 +302,17 @@ class PagedLion32bit(Optimizer1State):
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
......@@ -21,6 +21,7 @@ class GlobalOptimManager:
"""
A global optimizer manager for enabling custom optimizer configs.
"""
_instance = None
def __init__(self):
......@@ -48,13 +49,9 @@ class GlobalOptimManager:
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[
id(p)
]
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
def override_config(
self, parameters, key=None, value=None, key_value_dict=None
):
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
"""
Override initial optimizer config with specific hyperparameters.
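A hedged usage sketch of this override mechanism, following the pattern documented for bitsandbytes (the "optim_bits" key is the one used elsewhere in this file; the toy model is a placeholder):

import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
mng.register_parameters(model.parameters())   # register before creating the optimizer
model = model.cuda()

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
# Keep the first layer's weight in 32-bit optimizer state while the rest stay 8-bit.
mng.override_config(model[0].weight, "optim_bits", 32)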
......@@ -170,16 +167,12 @@ class Optimizer8bit(torch.optim.Optimizer):
saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
raise ValueError(
"loaded state dict has a different number of "
"parameter groups"
)
raise ValueError("loaded state dict has a different number of parameter groups")
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
"loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
)
# Update the state
......@@ -228,9 +221,7 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"]
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)
]
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
......@@ -240,7 +231,7 @@ class Optimizer8bit(torch.optim.Optimizer):
values = self.state[p]
for k, v in values.items():
if isinstance(v, torch.Tensor):
is_paged = getattr(v, 'is_paged', False)
is_paged = getattr(v, "is_paged", False)
if not is_paged:
self.state[p][k] = v.to(p.device)
......@@ -248,9 +239,7 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance(
pmodule, torch.Parameter
)
assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
found = False
for gindex, group in enumerate(self.param_groups):
if found:
......@@ -262,9 +251,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
self.mng.index2config[
(gindex, pindex)
] = self.mng.pid2config[id(p)]
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
......@@ -287,7 +274,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
#if self.is_paged: self.page_mng.prefetch_all()
# if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
......@@ -304,7 +291,6 @@ class Optimizer8bit(torch.optim.Optimizer):
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
def get_config(self, gindex, pindex, group):
......@@ -328,9 +314,7 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError("init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(
"The update_step method needs to be overridden"
)
raise NotImplementedError("The update_step method needs to be overridden")
def get_state_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 1e5:
......@@ -345,12 +329,12 @@ class Optimizer8bit(torch.optim.Optimizer):
def prefetch_state(self, p):
if self.is_paged:
state = self.state[p]
s1 = state['state1']
is_paged = getattr(s1, 'is_paged', False)
s1 = state["state1"]
is_paged = getattr(s1, "is_paged", False)
if is_paged:
F.prefetch_tensor(state['state1'])
if 'state2' in state:
F.prefetch_tensor(state['state2'])
F.prefetch_tensor(state["state1"])
if "state2" in state:
F.prefetch_tensor(state["state2"])
class Optimizer2State(Optimizer8bit):
......@@ -369,7 +353,7 @@ class Optimizer2State(Optimizer8bit):
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
is_paged=False
is_paged=False,
):
"""
Base 2-state update optimizer class.
......@@ -414,13 +398,9 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
f"Invalid beta parameter at index {i}: {betas[i]}"
)
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged)
......@@ -449,9 +429,7 @@ class Optimizer2State(Optimizer8bit):
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
......@@ -459,21 +437,15 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
p.device
)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
p.device
)
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"]
......@@ -486,25 +458,13 @@ class Optimizer2State(Optimizer8bit):
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
state["max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
......@@ -524,7 +484,10 @@ class Optimizer2State(Optimizer8bit):
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
)
else:
gnorm_scale = 1.0
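An illustrative pure-PyTorch reference for the idea behind F.percentile_clipping: keep a rolling window of recent gradient norms and rescale gradients whose norm exceeds a chosen percentile of that history. This is a sketch only; the fused kernel's exact percentile convention and storage format may differ.

import torch

def percentile_clipping_ref(grad, gnorm_vec, step, percentile=5):
    current_gnorm = torch.norm(grad.float())
    gnorm_vec[step % 100] = current_gnorm                  # rolling history of 100 norms
    valid = gnorm_vec[: min(step + 1, 100)]
    clip_value = torch.quantile(valid, percentile / 100.0)
    gnorm_scale = 1.0
    if current_gnorm > clip_value:
        gnorm_scale = (clip_value / current_gnorm).item()  # shrink oversized gradients
    return current_gnorm, clip_value, gnorm_scale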
......@@ -568,9 +531,7 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"]
if config["max_unorm"] > 0.0
else None,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
......@@ -615,7 +576,7 @@ class Optimizer1State(Optimizer8bit):
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
is_paged=False
is_paged=False,
):
"""
Base 1-state update optimizer class.
......@@ -656,13 +617,9 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
f"Invalid beta parameter at index {i}: {betas[i]}"
)
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged)
......@@ -691,9 +648,7 @@ class Optimizer1State(Optimizer8bit):
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
......@@ -701,17 +656,13 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
p.device
)
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"]
......@@ -721,16 +672,10 @@ class Optimizer1State(Optimizer8bit):
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
state["max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
......@@ -750,7 +695,10 @@ class Optimizer1State(Optimizer8bit):
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
)
else:
gnorm_scale = 1.0
......@@ -766,7 +714,7 @@ class Optimizer1State(Optimizer8bit):
step,
config["lr"],
None,
config['betas'][1],
config["betas"][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......
......@@ -51,9 +51,7 @@ class RMSprop(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
......@@ -116,9 +114,7 @@ class RMSprop8bit(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
......@@ -182,9 +178,7 @@ class RMSprop32bit(Optimizer1State):
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
......
......@@ -195,9 +195,9 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
......@@ -216,9 +216,7 @@ class SwitchBackBnb(torch.autograd.Function):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
A.to(torch.float16), threshold=state.threshold
)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
......@@ -234,14 +232,14 @@ class SwitchBackBnb(torch.autograd.Function):
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
#print('A shape', A.shape)
# print('A shape', A.shape)
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
# 2. Quantize B
if state.has_fp16_weights:
#print('B shape', B.shape)
# print('B shape', B.shape)
has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
......@@ -272,12 +270,7 @@ class SwitchBackBnb(torch.autograd.Function):
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.to(A.dtype)
)
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
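The outlier handling above follows the LLM.int8()-style decomposition: columns whose magnitude exceeds the threshold stay in higher precision (subA / state.subB) while the rest go through the int8 path. A simplified pure-PyTorch sketch of that split, not the fused kernels used here:

import torch

def decompose_int8_ref(A, threshold=6.0):
    outlier_cols = (A.abs() > threshold).any(dim=0).nonzero().flatten()
    A_rest = A.clone()
    A_rest[:, outlier_cols] = 0                            # zero outliers before quantizing
    scale = A_rest.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    CA = torch.round(A_rest / scale).to(torch.int8)        # row-wise int8 codes
    subA = A[:, outlier_cols]                              # higher-precision outlier slice
    return CA, scale, subA, outlier_cols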
......@@ -320,14 +313,13 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
return clone_func(output.view(output_shape))
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
......@@ -342,9 +334,7 @@ class SwitchBackBnb(torch.autograd.Function):
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
......@@ -357,25 +347,24 @@ class SwitchBackBnb(torch.autograd.Function):
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
# print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception('State must contain either CBt or CB matrix for backward')
raise Exception("State must contain either CBt or CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None
def get_block_sizes(input_matrix, weight_matrix):
input_features = input_matrix.shape[-1]
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
bsz, bsz2 = 1024, 1024
for i, k in enumerate(array):
......@@ -399,7 +388,8 @@ def matmul_fp8_global(
bsz: int = -1,
bsz2: int = -1,
):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
......@@ -412,7 +402,8 @@ def matmul_fp8_mixed(
bsz: int = -1,
bsz2: int = -1,
):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
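A hedged usage sketch for these wrappers, mirroring the call pattern of LinearFP8Mixed further down: the E4M3 map is used on the forward pass, E5M2 on the backward pass, and leaving bsz/bsz2 at -1 falls back to get_block_sizes. CUDA and the research build are assumed.

import torch
import bitsandbytes as bnb

x = torch.randn(16, 1024, device="cuda")
w = torch.randn(4096, 1024, device="cuda")

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)   # E4M3 forward map
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)   # E5M2 backward map

out = bnb.research.matmul_fp8_mixed(x, w.t(), fw_code=fw_code, bw_code=bw_code)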
......@@ -422,7 +413,7 @@ def switchback_bnb(
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None
bias=None,
):
state = state or MatmulLtState()
if threshold > 0.0:
......
......@@ -28,12 +28,20 @@ class LinearFP8Mixed(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
out = bnb.research.matmul_fp8_mixed(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None:
out += self.bias
return out
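A short, hedged example of using LinearFP8Mixed as a drop-in replacement for nn.Linear (assuming the bnb.research.nn import path for this module and a CUDA device; the fp8 codebooks are created lazily on the first forward, as shown above):

import torch
import bitsandbytes as bnb

layer = bnb.research.nn.LinearFP8Mixed(1024, 4096).cuda()
x = torch.randn(8, 1024, device="cuda")
y = layer(x)             # forward uses the E4M3 map; backward would use E5M2
print(y.shape)           # torch.Size([8, 4096])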
class LinearFP8Global(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
......@@ -54,7 +62,14 @@ class LinearFP8Global(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
out = bnb.matmul_fp8_global(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None:
out += self.bias
......
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -29,7 +30,7 @@ else:
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
key=["n_elements"],
)
@triton.jit
def _dequantize_rowwise(
......@@ -51,7 +52,6 @@ else:
output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
......@@ -60,5 +60,5 @@ else:
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
_dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output
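A pure-PyTorch reference for what the Triton kernel computes, useful when Triton is unavailable: each row of int8 codes is rescaled by that row's absmax divided by 127. This is an illustrative equivalent, not the kernel itself.

import torch

def dequantize_rowwise_ref(x_int8, state_x):
    # x_int8: (rows, cols) int8 codes; state_x: (rows,) per-row absmax values.
    return x_int8.to(torch.float16) * (state_x.view(-1, 1).to(torch.float16) / 127.0)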