Commit 5a4263f4 authored by Ruff, committed by Aarni Koskela

Reformat with ruff-format

parent 02e30ca6
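For context on how a commit like this is typically produced, the sketch below is a hypothetical helper (not part of this change) that shells out to Ruff's formatter; the actual formatting behaviour comes from the repository's own configuration.

# Hypothetical reproduction helper: run Ruff's formatter over the working tree,
# the kind of invocation behind a "Reformat with ruff-format" commit.
# Assumes the `ruff` executable is installed and on PATH.
import subprocess
import sys


def run_ruff_format(check_only: bool = False) -> int:
    cmd = ["ruff", "format", "."]
    if check_only:
        cmd.append("--check")  # report files that would be reformatted without rewriting them
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(run_ruff_format(check_only="--check" in sys.argv[1:]))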
@@ -7,9 +7,7 @@ def get_platform_tag(architecture):
    system = platform.system()
    if system == "Linux":
        tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
    elif system == "Darwin":
        tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
    elif system == "Windows":
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

cmap = plt.get_cmap("cool")

if __name__ == "__main__":
    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
@@ -19,25 +17,28 @@ if __name__ == '__main__':
    ax = fig.add_subplot(gs[0, 0])

    # TODO: change this to what you want.
    rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # first plot the time occupied by different operations
    for k, marker, ls, color, name in [
        ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
        (
            "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
            "o",
            "-",
            "C4",
            "SwitchBack int8 (sum of parts)",
        ),
        ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
        ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
        ("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
        ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
        ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
        ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
        ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
        ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
        ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
    ]:
        xs = []
        ys = []
@@ -47,40 +48,46 @@ if __name__ == '__main__':
            df_ = df_[df_.dim_out == embed_dim * 4]
            xs.append(embed_dim)

            y_ = 0
            for k_ in k.split("+"):
                y_ += df_[k_].values[0]
            df_ = df[df.dim_in == embed_dim * 4]
            df_ = df_[df_.dim_out == embed_dim]
            for k_ in k.split("+"):
                y_ += df_[k_].values[0]

            ys.append(y_ * 0.5)

        ax.plot(
            xs,
            ys,
            color=color,
            label=name,
            marker=marker,
            markersize=5 if marker == "s" else 5,
            linestyle=ls,
            linewidth=2 if "+" in k else 1.0,
        )

    ax.set_xlabel("dim", fontsize=13)
    ax.set_ylabel("time (ms)", fontsize=13)

    ax.grid()
    ax.set_xscale("log")
    if logscale_plot1:
        ax.set_yscale("log")

    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
    leg.get_texts()[0].set_fontweight("bold")
    leg.get_texts()[1].set_fontweight("bold")
    plt.subplots_adjust(left=0.1)
    ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)

    ax = fig.add_subplot(gs[0, 1])
@@ -88,10 +95,15 @@ if __name__ == '__main__':
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        all_xs, all_ys = [], []
        for k, marker, ls, color, name in [
            ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
            (
                "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
                "o",
                "-",
                "C4",
                "SwitchBack int8 (total time)",
            ),
        ]:
            xs, ys = [], []
            df = rdf[rdf.batch_size == batch_size]
            for embed_dim in dims_to_consider:
@@ -99,11 +111,11 @@ if __name__ == '__main__':
                df_ = df_[df_.dim_out == embed_dim * 4]
                xs.append(embed_dim)

                y_ = 0
                for k_ in k.split("+"):
                    y_ += df_[k_].values[0]
                df_ = df[df.dim_in == embed_dim * 4]
                df_ = df_[df_.dim_out == embed_dim]
                for k_ in k.split("+"):
                    y_ += df_[k_].values[0]

                ys.append(y_ * 0.5)

            all_xs.append(xs)
@@ -111,25 +123,29 @@ if __name__ == '__main__':
        color = cmap(j * 0.25)
        real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
        markers = ["^", "v", "P", "o"]
        ax.plot(
            all_xs[0],
            real_ys,
            color=color,
            label=f"batch * sequence length = {batch_size}",
            marker=markers[j],
            markersize=5 if marker == "s" else 5,
        )

    ax.legend()
    ax.set_xlabel("dim", fontsize=13)
    ax.set_xscale("log")
    ax.grid()
    ax.set_ylabel(r"% speedup", fontsize=13)

    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)

    plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
@@ -20,8 +20,8 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.


def get_time(k, fn, info_dict):
    for _ in range(repeat // 2):
        fn()
@@ -36,16 +36,15 @@ def get_time(k, fn, info_dict):
    print(f"time {k}: {ms:.3f} ms")
    info_dict[k] = ms


if __name__ == "__main__":
    torch.manual_seed(0)
    wm = 4
    for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
        # note "batch_size" is actually "batch_size * embed_dim", which is why it's large
        for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
            # switch switches dim_in and dim_out
            for switch in [False, True]:
                # hparams
                repeat = 64
                batch_size = batch_size
@@ -73,35 +72,86 @@ if __name__ == '__main__':
                state_w_rowwise = w.max(dim=1)[0]
                state_w_global = w.max()

                info = {
                    "repeat": repeat,
                    "batch_size": batch_size,
                    "dim_out": dim_out,
                    "dim_in": dim_in,
                    "wm": wm,
                    "switch": switch,
                }

                get_time("standard_fwd", lambda: x.matmul(w.t()), info)
                get_time("standard_gw", lambda: g.t().matmul(x), info)
                get_time("standard_gx", lambda: g.matmul(w), info)
                get_time(
                    "rowwise_fwd",
                    lambda: int8_matmul_rowwise_dequantize(
                        x_int8,
                        w_int8.t(),
                        state_x_rowwise,
                        state_w_columnwise,
                        None,
                    ),
                    info,
                )
                get_time(
                    "rowwise_bwd",
                    lambda: int8_matmul_rowwise_dequantize(
                        g_int8,
                        wt_int8.t(),
                        state_x_rowwise,
                        state_w_rowwise,
                        None,
                    ),
                    info,
                )
                get_time(
                    "global_fwd",
                    lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
                    info,
                )
                get_time(
                    "global_bwd",
                    lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
                    info,
                )
                get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
                get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
                get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
                get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
                get_time("w_quantize_global", lambda: quantize_global(w), info)
                get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)

                time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
                time_rowwise = (
                    info["x_quantize_rowwise"]
                    + info["g_quantize_rowwise"]
                    + info["w_quantize_colwise_transpose"]
                    + info["w_quantize_rowwise"]
                    + info["standard_gw"]
                    + info["rowwise_fwd"]
                    + info["rowwise_bwd"]
                )
                time_global = (
                    info["x_quantize_rowwise"]
                    + info["g_quantize_rowwise"]
                    + info["w_quantize_global"]
                    + info["w_quantize_global_transpose"]
                    + info["standard_gw"]
                    + info["global_fwd"]
                    + info["global_bwd"]
                )

                print("TOTAL STANDARD", time_standard)
                print("TOTAL ROWWISE", time_rowwise)
                print("TOTAL GLOBAL", time_global)

                print("speedup", -100 * (time_global - time_standard) / time_standard)

                info["time_standard"] = time_standard
                info["time_rowwise"] = time_rowwise
                info["time_global"] = time_global

                info_json = json.dumps(info)
@@ -14,16 +14,18 @@ import bitsandbytes.functional as F
def prod(iterable):
    return reduce(operator.mul, iterable, 1)


# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""


class GlobalOutlierPooler:
    _instance = None
@@ -83,6 +85,7 @@ def get_inverse_transform_indices(
            break  # if all indices fit in i bytes, stop early
    return permuted_tile_indices


def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
    """
    Undo a tiled permutation such as turing or ampere layout
@@ -159,20 +162,12 @@ class MatMul8bit(torch.autograd.Function):
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
                else:
                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
@@ -201,9 +196,7 @@ class MatMul8bit(torch.autograd.Function):
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
@@ -227,7 +220,7 @@ def supports_igemmlt(device: torch.device) -> bool:
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    device_name = torch.cuda.get_device_name(device=device)
    nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660")  # https://en.wikipedia.org/wiki/GeForce_16_series
    if any(model_name in device_name for model_name in nvidia16_models):
        return False  # these devices are technically cuda 7.5-capable, but they lack tensor cores
    return True
@@ -246,6 +239,7 @@ def get_tile_inds(format, device):
    with torch.no_grad():
        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)


@dataclass
class MatmulLtState:
    _tile_indices: Optional[torch.Tensor] = None
@@ -510,7 +504,6 @@ class MatMul4Bit(torch.autograd.Function):
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. MatmulnN
        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
@@ -532,7 +525,7 @@ class MatMul4Bit(torch.autograd.Function):
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
        A, B = ctx.tensors
        grad_A, grad_B, grad_bias = None, None, None
@@ -542,8 +535,9 @@ class MatMul4Bit(torch.autograd.Function):
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # not supported by PyTorch. TODO: create work-around
        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
        if req_gradA:
            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

        return grad_A, grad_B, None, grad_bias, None
@@ -554,7 +548,7 @@ def matmul(
    out: Optional[torch.Tensor] = None,
    state: Optional[MatmulLtState] = None,
    threshold=0.0,
    bias=None,
):
    state = state or MatmulLtState()
    if threshold > 0.0:
@@ -562,11 +556,19 @@ def matmul(
    return MatMul8bitLt.apply(A, B, out, bias, state)


def matmul_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    quant_state: F.QuantState,
    out: Optional[torch.Tensor] = None,
    bias=None,
):
    assert quant_state is not None
    if A.numel() == A.shape[-1] and A.requires_grad == False:
        if A.shape[-1] % quant_state.blocksize != 0:
            warn(
                f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
            )
            return MatMul4Bit.apply(A, B, out, bias, quant_state)
        else:
            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
@@ -56,7 +56,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
            "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
            "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
            "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
        )
    return PACKAGE_DIR / library_name
@@ -100,7 +100,7 @@ def get_native_library() -> BNBNativeLibrary:
        logger.warning(
            "The installed version of bitsandbytes was compiled without GPU support. "
            "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
        )
    return BNBNativeLibrary(dll)
@@ -120,5 +120,5 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
""",
    )
@@ -120,7 +120,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
The CUDA version for the compile might depend on your conda install, if using conda.
Inspect CUDA version via `conda list | grep cuda`.
""",
    )
    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
@@ -129,7 +129,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
            """
WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
You will be only to use 8-bit optimizers and quantization routines!
""",
        )
    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
@@ -170,7 +170,7 @@ def print_cuda_runtime_diagnostics() -> None:
In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
""",
        )
    for pth in cudart_paths:
        print(f"* Found CUDA runtime at: {pth}")
@@ -25,7 +25,7 @@ def sanity_check():
See the documentation for more details if needed.
Trying a simple check anyway, but this will likely fail...
""",
        )
    from bitsandbytes.optim import Adam
@@ -71,7 +71,7 @@ def main():
        print(
            f"WARNING: {__package__} is currently running as CPU-only!\n"
            "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
            f"If you think that this is so erroneously,\nplease report an issue!",
        )
    except Exception:
        traceback.print_exc()
@@ -80,6 +80,6 @@ def main():
Above we output some debug information.
Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
WARNING: Please be sure to sanitize sensitive info from the output before posting it.
""",
    )
    sys.exit(1)
@@ -21,6 +21,7 @@ from .cextension import lib
def prod(iterable):
    return reduce(operator.mul, iterable, 1)


name2qmap = {}

if lib and lib.compiled_with_cuda:
@@ -127,7 +128,6 @@ class GlobalPageManager:
            prefetch_tensor(t, to_cpu)


class CUBLAS_Context:
    _instance = None
@@ -169,6 +169,7 @@ class Cusparse_Context:
            cls._instance.initialize()
        return cls._instance


dtype2bytes = {}
dtype2bytes[torch.float32] = 4
dtype2bytes[torch.float16] = 2
@@ -176,10 +177,11 @@ dtype2bytes[torch.bfloat16] = 2
dtype2bytes[torch.uint8] = 1
dtype2bytes[torch.int8] = 1

FIRST_CUDA_DEVICE = torch.device("cuda", index=0)


def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
    num_bytes = dtype2bytes[dtype] * prod(shape)
    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
    new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
@@ -188,31 +190,35 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
    out.page_deviceid = device.index
    return out


def prefetch_tensor(A, to_cpu=False):
    assert A.is_paged, "Only paged tensors can be prefetched!"
    if to_cpu:
        deviceid = -1
    else:
        deviceid = A.page_deviceid

    num_bytes = dtype2bytes[A.dtype] * A.numel()
    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))


def elementwise_func(func_name, A, B, value, prefetch=True):
    func = None
    if A.dtype == torch.float32:
        func = getattr(lib, f"c{func_name}_fp32", None)
        cvalue = ct.c_float(value)
    elif A.dtype == torch.uint8:
        func = getattr(lib, f"c{func_name}_uint8", None)
        cvalue = ct.c_uint8(value)
    if func is None:
        raise NotImplementedError(f"Function not implemented: {func_name}")

    is_managed = getattr(A, "is_managed", False)
    if is_managed and prefetch:
        prefetch_tensor(A)
        if B is not None:
            prefetch_tensor(B)

    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
    if A.is_paged or B.is_paged:
@@ -222,28 +228,36 @@ def elementwise_func(func_name, A, B, value, prefetch=True):
        # operation occurred. So we synchronize.
        torch.cuda.synchronize()


def fill(A, value, device=None, prefetch=True):
    elementwise_func("fill", A, None, value)


def arange(A, device=None):
    elementwise_func("arange", A, None, 0)


def _mul(A, B, device=None):
    elementwise_func("_mul", A, B, 0)


def create_linear_map(signed=True, total_bits=8, add_zero=True):
    sign = -1.0 if signed else 0.0
    total_values = 2**total_bits
    if add_zero or total_bits < 8:
        # add a zero
        # since we simulate less bits by having zeros in the data type, we
        # we need to center the quantization around zero and as such lose
        # a single value
        total_values = 2**total_bits if not signed else 2**total_bits - 1

    values = torch.linspace(sign, 1.0, total_values)
    gap = 256 - values.numel()
    if gap == 0:
        return values
    else:
        l = values.numel() // 2  # noqa: E741
        return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist())


def create_normal_map(offset=0.9677083, use_extra_value=True):
@@ -251,18 +265,17 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
        from scipy.stats import norm
    except ImportError as ie:
        raise ImportError(
            "Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.",
        ) from ie

    if use_extra_value:
        # one more positive value, this is an asymmetric type
        v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
        v2 = [0] * (256 - 15)  ## we have 15 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
    else:
        v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
        v2 = [0] * (256 - 14)  ## we have 14 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()

    v = v1 + v2 + v3
@@ -275,38 +288,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
    return values


def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
    e = exponent_bits
    p = precision_bits
    has_sign = 1 if signed else 0
    assert e + p == total_bits - has_sign
    # the exponent is biased to 2^(e-1) -1 == 0
    evalues = []
    pvalues = []
    for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
        evalues.append(2**val)

    values = []
    lst = list(itertools.product([0, 1], repeat=precision_bits))
    # for ev in evalues:
    bias = 2 ** (exponent_bits - 1)
    for evalue in range(2 ** (exponent_bits)):
        for bit_pattern in lst:
            value = 1 if evalue != 0 else 0
            for i, pval in enumerate(list(bit_pattern)):
                value += pval * (2 ** -(i + 1))
            if evalue == 0:
                # subnormals
                value = value * 2**-(bias)
            else:
                # normals
                value = value * 2 ** -(evalue - bias - 1)
            values.append(value)
            if signed:
                values.append(-value)

    assert len(values) == 2**total_bits
    values.sort()
    if total_bits < 8:
@@ -320,7 +332,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
    return code


def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    """
    Creates the dynamic quantiztion map.
@@ -345,7 +356,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    non_sign_bits = total_bits - (1 if signed else 1)
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    for i in range(max_exponent_bits):
        fraction_items = int(
            2 ** (i + non_sign_bits - max_exponent_bits) + 1
            if signed
            else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
        )
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
@@ -371,8 +386,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    data.sort()
    return Tensor(data)


def create_quantile_map(A, total_bits=8):
    q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
    q = q.tolist()
    q.append(0)
@@ -383,11 +399,13 @@ def create_quantile_map(A, total_bits=8):
    q.sort()

    q = Tensor(q)
    q = q / q.abs().max()
    return q


def get_special_format_str():
    if not torch.cuda.is_available():
        return "col_turing"
    major, _minor = torch.cuda.get_device_capability()
    if major <= 7:
        return "col_turing"
@@ -396,20 +414,24 @@ def get_special_format_str():
        return "col_turing"


def is_on_gpu(tensors):
    on_gpu = True
    gpu_ids = set()
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        is_paged = getattr(t, "is_paged", False)
        on_gpu &= t.device.type == "cuda" or is_paged
        if not is_paged:
            gpu_ids.add(t.device.index)
    if not on_gpu:
        raise TypeError(
            f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}",
        )
    if len(gpu_ids) > 1:
        raise TypeError(
            f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}",
        )
    return on_gpu
@@ -447,15 +469,13 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
    if not hasattr(lib, name):
        print(name)
        raise ValueError(
            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
        )
    else:
        return getattr(lib, name)


def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
    # init_func = torch.empty
    init_func = torch.zeros
    dims = len(shape)
@@ -508,9 +528,7 @@ def nvidia_transform(
    else:
        from_order = state[1]
    if out is None:
        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
    else:
        new_state = (state[1], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)
@@ -534,8 +552,13 @@ def nvidia_transform(
    return out, new_state


def estimate_quantiles(
    A: Tensor,
    out: Optional[torch.Tensor] = None,
    offset: float = 1 / 512,
    num_quantiles=256,
) -> Tensor:
    """
    Estimates 256 equidistant quantiles on the input tensor eCDF.
    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -562,14 +585,21 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.
    """
    if A.numel() < 256:
        raise NotImplementedError(
            f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.",
        )
    if num_quantiles > 256:
        raise NotImplementedError(
            f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}",
        )
    if num_quantiles < 256 and offset == 1 / (512):
        # override default arguments
        offset = 1 / (2 * num_quantiles)
    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    is_on_gpu([A, out])
    device = pre_call(A.device)
    if A.dtype == torch.float32:
@@ -581,7 +611,7 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
    post_call(device)

    if num_quantiles < 256:
        step = round(256 / num_quantiles)
        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
        out = out[idx]
@@ -590,12 +620,35 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl
class QuantState:
    """container for quantization state components to work with Params4bit and similar classes"""

    valid_quant_types = ("fp4", "nf4")
    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
    valid_qs_keys = [
        "absmax",
        "quant_map",
        "nested_absmax",
        "nested_quant_map",
        "quant_state",
        "quant_type",
        "blocksize",
        "dtype",
        "shape",
        "nested_blocksize",
        "nested_dtype",
        "nested_offset",
    ]

    def __init__(
        self,
        absmax,
        shape=None,
        code=None,
        blocksize=None,
        quant_type=None,
        dtype=None,
        offset=None,
        state2=None,
    ):
        self.absmax = absmax
        self.shape = shape
        self.code = code
@@ -614,13 +667,20 @@ class QuantState:
        state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
        """
        if self.nested:
            list_repr = [
                self.absmax,
                self.shape,
                self.dtype,
                self.blocksize,
                [self.offset, self.state2],
                self.quant_type,
            ]
        else:
            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
        return list_repr[idx]

    @classmethod
    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState":
        """
        unpacks components of state_dict into QuantState
        where necessary, convert into strings, torch.dtype, ints, etc.
@@ -632,37 +692,39 @@ class QuantState:
        # unpacking tensor with non-tensor components
        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
        if not len(qs_key) and "quant_type" not in qs_dict:
            raise ValueError("Expected packed or unpacked quant_state items, found neither")
        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
            raise ValueError(
                f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
            )

        # unpacking minor and non-tensor quant state items if necessary
        if len(qs_key) == 1:
            first_qs_key = qs_key[0]
            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))

        qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()}  # strip prefixes
        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

        if "nested_absmax" in qs_dict:
            offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
            state2 = cls(
                absmax=qs_dict["nested_absmax"].to(device),
                blocksize=qs_dict["nested_blocksize"],
                code=qs_dict["nested_quant_map"].to(device),
                dtype=getattr(torch, qs_dict["nested_dtype"]),
            )
        else:
            offset, state2 = None, None

        quant_state = cls(
            quant_type=qs_dict["quant_type"],
            absmax=qs_dict["absmax"].to(device),
            blocksize=qs_dict["blocksize"],
            code=qs_dict["quant_map"].to(device),
            dtype=getattr(torch, qs_dict["dtype"]),
            shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None,
            offset=offset,
            state2=state2,
        )
@@ -674,21 +736,23 @@ class QuantState:
        param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
        """
        qs_dict = {
            "quant_type": self.quant_type,
            "absmax": self.absmax,
            "blocksize": self.blocksize,
            "quant_map": self.code,
            "dtype": str(self.dtype).strip("torch."),
            "shape": tuple(self.shape),
        }
        if self.nested:
            qs_dict.update(
                {
                    "nested_absmax": self.state2.absmax,
                    "nested_blocksize": self.state2.blocksize,
                    "nested_quant_map": self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
                    "nested_dtype": str(self.state2.dtype).strip("torch."),
                    "nested_offset": self.offset.item(),
                },
            )
        if not packed:
            return qs_dict
@@ -711,14 +775,22 @@ class QuantState:
            return False

        return (
            torch.allclose(self.absmax, other.absmax, atol=1e-6)
            and self.shape == other.shape
            and torch.allclose(self.code, other.code, atol=1e-6)
            and self.dtype == other.dtype
            and self.blocksize == other.blocksize
            and self.quant_type == other.quant_type
            and (
                self.offset == other.offset
                if self.offset is not None and other.offset is not None
                else self.offset is other.offset
            )
            and (
                self.state2 == other.state2
                if self.state2 is not None and other.state2 is not None
                else self.state2 is other.state2
            )
        )
@@ -756,7 +828,6 @@ def quantize_blockwise(
        The quantization state to undo the quantization.
    """

    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -771,31 +842,66 @@ def quantize_blockwise(
    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != "cpu":
        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
        cblocksize = ct.c_int32(blocksize)
        prev_device = pre_call(A.device)
        code = code.to(A.device)
        is_on_gpu([code, A, out, absmax])
        if A.dtype == torch.float32:
            lib.cquantize_blockwise_fp32(
                get_ptr(code),
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                cblocksize,
                ct.c_int(A.numel()),
            )
        elif A.dtype == torch.float16:
            lib.cquantize_blockwise_fp16(
                get_ptr(code),
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                cblocksize,
                ct.c_int(A.numel()),
            )
        elif A.dtype == torch.bfloat16:
            lib.cquantize_blockwise_bf16(
                get_ptr(code),
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                cblocksize,
                ct.c_int(A.numel()),
            )
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        post_call(A.device)
    else:
        # cpu
        code = code.cpu()
        lib.cquantize_blockwise_cpu_fp32(
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_longlong(blocksize),
            ct.c_longlong(A.numel()),
        )

    if nested:
        offset = absmax.mean()
        absmax -= offset
        qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
        quant_state = QuantState(
            absmax=qabsmax,
            code=code,
            blocksize=blocksize,
            dtype=A.dtype,
            offset=offset,
            state2=state2,
        )
    else:
        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
...@@ -809,7 +915,7 @@ def dequantize_blockwise( ...@@ -809,7 +915,7 @@ def dequantize_blockwise(
code: Optional[torch.Tensor] = None, code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize: int = 4096, blocksize: int = 4096,
nested=False nested=False,
) -> Tensor: ) -> Tensor:
""" """
Dequantizes blockwise quantized values. Dequantizes blockwise quantized values.
...@@ -849,37 +955,70 @@ def dequantize_blockwise( ...@@ -849,37 +955,70 @@ def dequantize_blockwise(
if quant_state.nested: if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset absmax += quant_state.offset
if absmax.dtype != torch.float32: absmax = absmax.float() if absmax.dtype != torch.float32:
absmax = absmax.float()
if out is None: if out is None:
out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
if A.device.type != 'cpu': if A.device.type != "cpu":
device = pre_call(A.device) device = pre_call(A.device)
code = quant_state.code.to(A.device) code = quant_state.code.to(A.device)
if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") raise ValueError(
f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
)
is_on_gpu([A, absmax, out]) is_on_gpu([A, absmax, out])
if out.dtype == torch.float32: if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) lib.cdequantize_blockwise_fp32(
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
)
elif out.dtype == torch.float16: elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) lib.cdequantize_blockwise_fp16(
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
)
elif out.dtype == torch.bfloat16: elif out.dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) lib.cdequantize_blockwise_bf16(
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device) post_call(A.device)
else: else:
code = quant_state.code.cpu() code = quant_state.code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(quant_state.absmax),
get_ptr(out),
ct.c_longlong(quant_state.blocksize),
ct.c_longlong(A.numel()),
)
return out return out
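A minimal round-trip sketch of the blockwise pair above, assuming a CUDA build of bitsandbytes and assuming the module is imported as `F` (the tensor sizes are illustrative):

import torch
import bitsandbytes.functional as F

x = torch.randn(4096, dtype=torch.float16, device="cuda")
# quantize to 8-bit codes plus per-block absmax; nested=True additionally quantizes the absmax values
q, state = F.quantize_blockwise(x, blocksize=256, nested=True)
x_hat = F.dequantize_blockwise(q, state)
max_err = (x - x_hat).abs().max()  # small, bounded by the per-block absmax resolution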
def get_4bit_type(typename, device=None, blocksize=64): def get_4bit_type(typename, device=None, blocksize=64):
if device is None: device = 'cuda' if device is None:
device = "cuda"
data = None data = None
if typename == 'nf4': if typename == "nf4":
''' Implements the NF4 data type. """ Implements the NF4 data type.
Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
is normalized into the range [-1, 1]. is normalized into the range [-1, 1].
...@@ -888,12 +1027,26 @@ def get_4bit_type(typename, device=None, blocksize=64): ...@@ -888,12 +1027,26 @@ def get_4bit_type(typename, device=None, blocksize=64):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
''' """
data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, data = [
-0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, -1.0,
0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, -0.6961928009986877,
0.7229568362236023, 1.0] -0.5250730514526367,
elif typename == 'fp4': -0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
elif typename == "fp4":
# 0b000 = 0 # 0b000 = 0
# 0b001 = 0.0625 # 0b001 = 0.0625
# 0b010 = 8 # 0b010 = 8
...@@ -904,20 +1057,35 @@ def get_4bit_type(typename, device=None, blocksize=64): ...@@ -904,20 +1057,35 @@ def get_4bit_type(typename, device=None, blocksize=64):
# 0b111 = 3 # 0b111 = 3
# can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
elif typename == 'int4': elif typename == "int4":
data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
elif typename == 'af4': elif typename == "af4":
# Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
# https://arxiv.org/abs/2306.06965 # https://arxiv.org/abs/2306.06965
if blocksize == 64: if blocksize == 64:
data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, data = [
-0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, -1.0,
0.42563882, 0.55496234, 0.72424863, 1.][::-1] -0.69441008,
-0.51243739,
-0.3736951,
-0.25607552,
-0.14982478,
-0.04934812,
0.0,
0.04273164,
0.12934483,
0.21961274,
0.31675666,
0.42563882,
0.55496234,
0.72424863,
1.0,
][::-1]
else: else:
raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.")
if data is None: if data is None:
raise NotImplementedError(f'Typename {typename} not supported') raise NotImplementedError(f"Typename {typename} not supported")
data = Tensor(data) data = Tensor(data)
data /= data.abs().max() data /= data.abs().max()
...@@ -926,11 +1094,26 @@ def get_4bit_type(typename, device=None, blocksize=64): ...@@ -926,11 +1094,26 @@ def get_4bit_type(typename, device=None, blocksize=64):
return data.to(device) return data.to(device)
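The NF4 codebook above is precomputed. As a rough illustration of the equal-area idea only (not the exact `create_normal_map` construction referenced in the docstring), one can place 16 code points at the midpoints of equal-probability bins of N(0, 1) and rescale into [-1, 1]:

from statistics import NormalDist

nd = NormalDist()                                        # standard normal N(0, 1)
mids = [nd.inv_cdf((i + 0.5) / 16) for i in range(16)]   # midpoints of 16 equal-mass bins
scale = max(abs(q) for q in mids)
approx_codebook = [q / scale for q in mids]              # symmetric, spans roughly [-1, 1]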
def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): def quantize_fp4(
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) A: Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_storage=torch.uint8,
):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): def quantize_nf4(
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) A: Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_storage=torch.uint8,
):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
def quantize_4bit( def quantize_4bit(
...@@ -939,7 +1122,7 @@ def quantize_4bit( ...@@ -939,7 +1122,7 @@ def quantize_4bit(
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize=64, blocksize=64,
compress_statistics=False, compress_statistics=False,
quant_type='fp4', quant_type="fp4",
quant_storage=torch.uint8, quant_storage=torch.uint8,
) -> Tuple[Tensor, QuantState]: ) -> Tuple[Tensor, QuantState]:
""" """
...@@ -967,10 +1150,10 @@ def quantize_4bit( ...@@ -967,10 +1150,10 @@ def quantize_4bit(
tuple(torch.Tensor, torch.Size, torch.dtype, int): tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization. The quantization state to undo the quantization.
""" """
if A.device.type != 'cuda': if A.device.type != "cuda":
raise NotImplementedError(f'Device type not supported for 4-bit quantization: {A.device.type}') raise NotImplementedError(f"Device type not supported for 4-bit quantization: {A.device.type}")
if quant_type not in ['fp4', 'nf4']: if quant_type not in ["fp4", "nf4"]:
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
n = A.numel() n = A.numel()
input_shape = A.shape input_shape = A.shape
...@@ -980,10 +1163,9 @@ def quantize_4bit( ...@@ -980,10 +1163,9 @@ def quantize_4bit(
blocks += 1 if n % blocksize > 0 else 0 blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
if out is None: if out is None:
mod = dtype2bytes[quant_storage] * 2 mod = dtype2bytes[quant_storage] * 2
out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
...@@ -991,20 +1173,62 @@ def quantize_4bit( ...@@ -991,20 +1173,62 @@ def quantize_4bit(
is_on_gpu([A, out, absmax]) is_on_gpu([A, out, absmax])
if A.dtype == torch.float32: if A.dtype == torch.float32:
if quant_type == 'fp4': if quant_type == "fp4":
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) lib.cquantize_blockwise_fp32_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) lib.cquantize_blockwise_fp32_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
elif A.dtype == torch.float16: elif A.dtype == torch.float16:
if quant_type == 'fp4': if quant_type == "fp4":
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) lib.cquantize_blockwise_fp16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) lib.cquantize_blockwise_fp16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
elif A.dtype == torch.bfloat16: elif A.dtype == torch.bfloat16:
if quant_type == 'fp4': if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) lib.cquantize_blockwise_bf16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) lib.cquantize_blockwise_bf16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device) post_call(A.device)
...@@ -1016,19 +1240,57 @@ def quantize_4bit( ...@@ -1016,19 +1240,57 @@ def quantize_4bit(
absmax -= offset absmax -= offset
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
del absmax del absmax
state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) state = QuantState(
absmax=qabsmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
offset=offset,
state2=state2,
)
else: else:
state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) state = QuantState(
absmax=absmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
)
return out, state return out, state
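The `(n + 1) // mod` sizing above comes from packing two 4-bit codes into each byte of `quant_storage`. A plain-Python illustration of that packing, with hypothetical code values:

codes = [3, 12, 7, 0]                                                    # 4-bit indices in [0, 15]
packed = [(hi << 4) | lo for hi, lo in zip(codes[0::2], codes[1::2])]    # two codes per byte
unpacked = [c for byte in packed for c in (byte >> 4, byte & 0xF)]
assert unpacked == codes                                                 # pack/unpack of the indices is lossless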
def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: def dequantize_fp4(
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') A: Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
def dequantize_nf4(
A: Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
) -> Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: def dequantize_4bit(
A: Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type="fp4",
) -> Tensor:
""" """
Dequantizes 4-bit (FP4 or NF4) blockwise quantized values. Dequantizes 4-bit (FP4 or NF4) blockwise quantized values.
...@@ -1056,23 +1318,31 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: ...@@ -1056,23 +1318,31 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax:
Dequantized tensor. Dequantized tensor.
""" """
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") raise ValueError(
if quant_type not in ['fp4', 'nf4']: f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') )
if quant_type not in ["fp4", "nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
if quant_state is None: if quant_state is None:
assert absmax is not None and out is not None assert absmax is not None and out is not None
quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) quant_state = QuantState(
absmax=absmax,
shape=out.shape,
dtype=out.dtype,
blocksize=blocksize,
quant_type=quant_type,
)
else: else:
absmax = quant_state.absmax absmax = quant_state.absmax
if quant_state.nested: if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset absmax += quant_state.offset
if absmax.dtype != torch.float32: absmax = absmax.float() if absmax.dtype != torch.float32:
absmax = absmax.float()
if out is None: if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
...@@ -1082,27 +1352,71 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: ...@@ -1082,27 +1352,71 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax:
device = pre_call(A.device) device = pre_call(A.device)
is_on_gpu([A, absmax, out]) is_on_gpu([A, absmax, out])
if out.dtype == torch.float32: if out.dtype == torch.float32:
if quant_state.quant_type == 'fp4': if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) lib.cdequantize_blockwise_fp32_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else: else:
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) lib.cdequantize_blockwise_fp32_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
elif out.dtype == torch.float16: elif out.dtype == torch.float16:
if quant_state.quant_type == 'fp4': if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) lib.cdequantize_blockwise_fp16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else: else:
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) lib.cdequantize_blockwise_fp16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
elif out.dtype == torch.bfloat16: elif out.dtype == torch.bfloat16:
if quant_state.quant_type == 'fp4': if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) lib.cdequantize_blockwise_bf16_fp4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else: else:
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) lib.cdequantize_blockwise_bf16_nf4(
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device) post_call(A.device)
is_transposed = (True if A.shape[0] == 1 else False) is_transposed = True if A.shape[0] == 1 else False
if is_transposed: return out.t() if is_transposed:
else: return out return out.t()
else:
return out
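An assumed end-to-end use of the 4-bit pair, on a CUDA device with bitsandbytes imported as `F` (names and sizes are illustrative; reconstruction error depends on blocksize and data distribution):

import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
qW, state = F.quantize_nf4(W, blocksize=64, compress_statistics=True)
W_hat = F.dequantize_nf4(qW, state)
rel_err = (W - W_hat).abs().mean() / W.abs().mean()   # typically a few percent for NF4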
def quantize( def quantize(
...@@ -1117,7 +1431,8 @@ def quantize( ...@@ -1117,7 +1431,8 @@ def quantize(
code = code.to(A.device) code = code.to(A.device)
absmax = torch.abs(A).max() absmax = torch.abs(A).max()
if absmax.dtype != torch.float32: absmax = absmax.float() if absmax.dtype != torch.float32:
absmax = absmax.float()
inp = A / absmax inp = A / absmax
out = quantize_no_absmax(inp, code, out) out = quantize_no_absmax(inp, code, out)
return out, (absmax, code) return out, (absmax, code)
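A CPU-only reference for the absmax + codebook scheme used by `quantize`/`quantize_no_absmax`; this is a sketch of the idea, not the CUDA kernel:

import torch

def quantize_ref(x, code):
    # code: sorted 256-entry codebook in [-1, 1]; returns uint8 indices and the global absmax
    absmax = x.abs().max().clamp(min=1e-12)
    xn = (x / absmax).clamp(-1, 1)
    idx = (xn.reshape(-1, 1) - code.reshape(1, -1)).abs().argmin(dim=1)
    return idx.to(torch.uint8).reshape(x.shape), absmax

def dequantize_ref(q, absmax, code):
    return code[q.long()] * absmax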
...@@ -1144,7 +1459,7 @@ def dequantize( ...@@ -1144,7 +1459,7 @@ def dequantize(
def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
''' """
Quantizes input tensor to 8-bit. Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
...@@ -1163,9 +1478,10 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No ...@@ -1163,9 +1478,10 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
------- -------
torch.Tensor: torch.Tensor:
Quantized 8-bit tensor. Quantized 8-bit tensor.
''' """
prev_device = pre_call(A.device) prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8) if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out]) is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device) post_call(prev_device)
...@@ -1173,7 +1489,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No ...@@ -1173,7 +1489,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
''' """
Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
...@@ -1192,9 +1508,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = ...@@ -1192,9 +1508,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
------- -------
torch.Tensor: torch.Tensor:
32-bit output tensor. 32-bit output tensor.
''' """
prev_device = pre_call(A.device) prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32) if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out]) is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device) post_call(prev_device)
...@@ -1261,16 +1578,17 @@ def optimizer_update_32bit( ...@@ -1261,16 +1578,17 @@ def optimizer_update_32bit(
if max_unorm > 0.0: if max_unorm > 0.0:
param_norm = torch.norm(p.data.float()) param_norm = torch.norm(p.data.float())
optim_func = None optim_func = None
if g.dtype == torch.float32: if g.dtype == torch.float32:
optim_func = str2optimizer32bit[optimizer_name][0] optim_func = str2optimizer32bit[optimizer_name][0]
elif g.dtype == torch.float16: elif g.dtype == torch.float16:
optim_func = str2optimizer32bit[optimizer_name][1] optim_func = str2optimizer32bit[optimizer_name][1]
elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
optim_func = str2optimizer32bit[optimizer_name][2] optim_func = str2optimizer32bit[optimizer_name][2]
else: else:
raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
)
is_on_gpu([g, p, state1, state2, unorm_vec]) is_on_gpu([g, p, state1, state2, unorm_vec])
prev_device = pre_call(g.device) prev_device = pre_call(g.device)
...@@ -1290,7 +1608,8 @@ def optimizer_update_32bit( ...@@ -1290,7 +1608,8 @@ def optimizer_update_32bit(
ct.c_float(lr), ct.c_float(lr),
ct.c_float(gnorm_scale), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_bool(skip_zeros),
ct.c_int32(g.numel())) ct.c_int32(g.numel()),
)
post_call(prev_device) post_call(prev_device)
...@@ -1422,7 +1741,7 @@ def optimizer_update_8bit( ...@@ -1422,7 +1741,7 @@ def optimizer_update_8bit(
) )
else: else:
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
) )
post_call(prev_device) post_call(prev_device)
...@@ -1446,7 +1765,6 @@ def optimizer_update_8bit_blockwise( ...@@ -1446,7 +1765,6 @@ def optimizer_update_8bit_blockwise(
gnorm_scale: float = 1.0, gnorm_scale: float = 1.0,
skip_zeros=False, skip_zeros=False,
) -> None: ) -> None:
optim_func = None optim_func = None
prev_device = pre_call(g.device) prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
...@@ -1454,12 +1772,15 @@ def optimizer_update_8bit_blockwise( ...@@ -1454,12 +1772,15 @@ def optimizer_update_8bit_blockwise(
optim_func = str2optimizer8bit_blockwise[optimizer_name][0] optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
elif g.dtype == torch.float16 and state1.dtype == torch.uint8: elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
optim_func = str2optimizer8bit_blockwise[optimizer_name][1] optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and elif (
len(str2optimizer8bit_blockwise[optimizer_name])==3): g.dtype == torch.bfloat16
and state1.dtype == torch.uint8
and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
):
optim_func = str2optimizer8bit_blockwise[optimizer_name][2] optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
else: else:
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
) )
post_call(prev_device) post_call(prev_device)
...@@ -1487,9 +1808,8 @@ def optimizer_update_8bit_blockwise( ...@@ -1487,9 +1808,8 @@ def optimizer_update_8bit_blockwise(
) )
post_call(prev_device) post_call(prev_device)
def percentile_clipping(
grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
):
"""Applies percentile clipping """Applies percentile clipping
grad: torch.Tensor grad: torch.Tensor
...@@ -1531,9 +1851,7 @@ def percentile_clipping( ...@@ -1531,9 +1851,7 @@ def percentile_clipping(
return current_gnorm, clip_value, gnorm_scale return current_gnorm, clip_value, gnorm_scale
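A simplified PyTorch sketch of the same idea, assuming a 100-entry history of squared gradient norms and an index-based percentile; the CUDA kernel's exact semantics may differ:

import torch

def percentile_clipping_ref(grad, gnorm_history, step, percentile=5):
    current_gnorm = torch.norm(grad.float())
    gnorm_history[step % 100] = current_gnorm**2              # rolling buffer of squared norms
    clip_value = torch.sort(gnorm_history)[0][percentile].sqrt()
    gnorm_scale = torch.where(current_gnorm > clip_value,
                              clip_value / current_gnorm,
                              torch.ones_like(current_gnorm))
    return current_gnorm, clip_value, gnorm_scale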
def histogram_scatter_add_2d( def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
assert len(histogram.shape) == 2 assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32 assert histogram.dtype == torch.float32
assert source.dtype == torch.float32 assert source.dtype == torch.float32
...@@ -1550,12 +1868,12 @@ def histogram_scatter_add_2d( ...@@ -1550,12 +1868,12 @@ def histogram_scatter_add_2d(
is_on_gpu([histogram, index1, index2, source]) is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized(): torch.cuda.init() if not torch.cuda.is_initialized():
torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type: if A.dtype != expected_type or B.dtype != expected_type:
raise TypeError( raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}")
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
)
sA = A.shape sA = A.shape
sB = B.shape sB = B.shape
...@@ -1596,12 +1914,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 ...@@ -1596,12 +1914,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
sout = out.shape sout = out.shape
# special case common in backprop # special case common in backprop
if not correct and len(sA) == 3 and len(sB) == 3: if not correct and len(sA) == 3 and len(sB) == 3:
if ( if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]:
sout[0] == sA[2]
and sout[1] == sB[2]
and sA[0] == sB[0]
and sA[1] == sB[1]
):
correct = True correct = True
else: else:
if len(sA) == 2 and len(sB) == 2: if len(sA) == 2 and len(sB) == 2:
...@@ -1634,26 +1947,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 ...@@ -1634,26 +1947,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
if not correct: if not correct:
raise ValueError( raise ValueError(
f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.",
) )
return sout return sout
def gemv_4bit( def gemv_4bit(
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
transposed_A=False, transposed_A=False,
transposed_B=False, transposed_B=False,
state=None state=None,
): ):
prev_device = pre_call(A.device) prev_device = pre_call(A.device)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if state is None: if state is None:
raise ValueError('state cannot be None. gemv_4bit() requires the state from quantize_4bit()') raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
if A.numel() != A.shape[-1]: if A.numel() != A.shape[-1]:
raise ValueError('Dimensions of A are invalid. Must be a vector with leading dimensions of 1, e.g. [1, 1, 2048]') raise ValueError(
'Dimensions of A are invalid. Must be a vector with leading dimensions of 1, e.g. [1, 1, 2048]',
)
Bshape = state.shape Bshape = state.shape
bout = Bshape[0] bout = Bshape[0]
...@@ -1673,7 +1989,7 @@ def gemv_4bit( ...@@ -1673,7 +1989,7 @@ def gemv_4bit(
k = Bshape[1] k = Bshape[1]
lda = Bshape[0] lda = Bshape[0]
ldc = Bshape[0] ldc = Bshape[0]
ldb = (A.shape[-1]+1)//2 ldb = (A.shape[-1] + 1) // 2
is_on_gpu([B, A, out, absmax, state.code]) is_on_gpu([B, A, out, absmax, state.code])
m = ct.c_int32(m) m = ct.c_int32(m)
n = ct.c_int32(n) n = ct.c_int32(n)
...@@ -1684,21 +2000,61 @@ def gemv_4bit( ...@@ -1684,21 +2000,61 @@ def gemv_4bit(
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
if A.dtype == torch.float16: if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) lib.cgemm_4bit_inference_naive_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
)
elif A.dtype == torch.bfloat16: elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) lib.cgemm_4bit_inference_naive_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
)
elif A.dtype == torch.float32: elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) lib.cgemm_4bit_inference_naive_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
)
else: else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
else: else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
post_call(prev_device) post_call(prev_device)
return out return out
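Conceptually, the naive 4-bit GEMV kernels compute a dense matvec against the dequantized weight. A hedged reference of what `gemv_4bit` should approximate up to rounding, on CUDA, with illustrative shapes:

import torch
import bitsandbytes.functional as F

W = torch.randn(256, 128, dtype=torch.float16, device="cuda")     # logical (out_features, in_features)
x = torch.randn(1, 1, 128, dtype=torch.float16, device="cuda")    # leading dimensions of 1, as required
qW, state = F.quantize_4bit(W, quant_type="nf4")
ref = x @ F.dequantize_4bit(qW, state).t()                        # dequantize-then-matmul reference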
def igemm( def igemm(
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
...@@ -1764,7 +2120,7 @@ def igemm( ...@@ -1764,7 +2120,7 @@ def igemm(
assert len(sA) == 3 assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]): if not (sA[0] == sB[0] and sA[1] == sB[1]):
raise ValueError( raise ValueError(
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}",
) )
transposed_A = True transposed_A = True
...@@ -1783,8 +2139,20 @@ def igemm( ...@@ -1783,8 +2139,20 @@ def igemm(
# B^T @ A^T = C^T # B^T @ A^T = C^T
# [km, nk -> mn] # [km, nk -> mn]
is_on_gpu([B, A, out]) is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), lib.cigemm(
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
)
return out return out
...@@ -1796,9 +2164,7 @@ def batched_igemm( ...@@ -1796,9 +2164,7 @@ def batched_igemm(
transposed_B=False, transposed_B=False,
): ):
if not len(A.shape) == 3 or not len(B.shape) == 3: if not len(A.shape) == 3 or not len(B.shape) == 3:
raise ValueError( raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}")
f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
)
sout = check_matmul(A, B, out, transposed_A, transposed_B) sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None: if out is None:
out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
...@@ -1865,9 +2231,24 @@ def batched_igemm( ...@@ -1865,9 +2231,24 @@ def batched_igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device) ptr = CUBLAS_Context.get_instance().get_context(A.device)
is_on_gpu([B, A, out]) is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), lib.cbatched_igemm(
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), ptr,
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
ct.c_long(strideA),
ct.c_long(strideB),
ct.c_long(strideC),
ct.c_uint32(num_batch),
)
return out return out
...@@ -1876,14 +2257,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ...@@ -1876,14 +2257,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0] shapeB = SB[0]
dimsA = len(shapeA) dimsA = len(shapeA)
dimsB = len(shapeB) dimsB = len(shapeB)
assert dimsB == 2, 'Only two-dimensional matrices are supported for argument B' assert dimsB == 2, "Only two-dimensional matrices are supported for argument B"
if dimsA == 2: if dimsA == 2:
m = shapeA[0] m = shapeA[0]
elif dimsA == 3: elif dimsA == 3:
m = shapeA[0] * shapeA[1] m = shapeA[0] * shapeA[1]
rows = n = shapeB[0] rows = n = shapeB[0]
assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}"
# if the tensor is empty, return a transformed empty tensor with the right dimensions # if the tensor is empty, return a transformed empty tensor with the right dimensions
if shapeA[0] == 0 and dimsA == 2: if shapeA[0] == 0 and dimsA == 2:
...@@ -1892,13 +2273,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ...@@ -1892,13 +2273,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None: if dimsA == 2 and out is None:
out, Sout = get_transform_buffer( out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
(shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
)
elif dimsA == 3 and out is None: elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer( out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
(shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
)
assert dimsB != 3, "len(B.shape)==3 not supported" assert dimsB != 3, "len(B.shape)==3 not supported"
assert A.device.type == "cuda" assert A.device.type == "cuda"
...@@ -1940,49 +2317,33 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ...@@ -1940,49 +2317,33 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0 has_error = 0
ptrRowScale = get_ptr(None) ptrRowScale = get_ptr(None)
is_on_gpu([A, B, out]) is_on_gpu([A, B, out])
if formatB == 'col_turing': if formatB == "col_turing":
if dtype == torch.int32: if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32( has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
else: else:
has_error = lib.cigemmlt_turing_8( has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
elif formatB == "col_ampere": elif formatB == "col_ampere":
if dtype == torch.int32: if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32( has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
else: else:
has_error = lib.cigemmlt_ampere_8( has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
if has_error: if has_error:
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}")
raise Exception('cublasLt ran into an error!') raise Exception("cublasLt ran into an error!")
torch.cuda.set_device(prev_device) torch.cuda.set_device(prev_device)
return out, Sout return out, Sout
def mm_dequant( def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
A,
quant_state,
row_stats,
col_stats,
out=None,
new_row_stats=None,
new_col_stats=None,
bias=None
):
assert A.dtype == torch.int32 assert A.dtype == torch.int32
if bias is not None: assert bias.dtype == torch.float16 if bias is not None:
assert bias.dtype == torch.float16
out_shape = quant_state[0] out_shape = quant_state[0]
if len(out_shape) == 3: if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2]) out_shape = (out_shape[0] * out_shape[1], out_shape[2])
...@@ -1990,19 +2351,11 @@ def mm_dequant( ...@@ -1990,19 +2351,11 @@ def mm_dequant(
if out is None: if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device) out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None: if new_row_stats is None:
new_row_stats = torch.empty( new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
out_shape[0], dtype=torch.float32, device=A.device
)
if new_col_stats is None: if new_col_stats is None:
new_col_stats = torch.empty( new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
out_shape[1], dtype=torch.float32, device=A.device assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
) assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
assert (
new_row_stats.shape[0] == row_stats.shape[0]
), f"{new_row_stats.shape} vs {row_stats.shape}"
assert (
new_col_stats.shape[0] == col_stats.shape[0]
), f"{new_col_stats.shape} vs {col_stats.shape}"
prev_device = pre_call(A.device) prev_device = pre_call(A.device)
ptrA = get_ptr(A) ptrA = get_ptr(A)
...@@ -2016,15 +2369,23 @@ def mm_dequant( ...@@ -2016,15 +2369,23 @@ def mm_dequant(
numCols = ct.c_int32(out_shape[1]) numCols = ct.c_int32(out_shape[1])
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) lib.cdequant_mm_int32_fp16(
ptrA,
ptrRowStats,
ptrColStats,
ptrOut,
ptrNewRowStats,
ptrNewColStats,
ptrBias,
numRows,
numCols,
)
post_call(prev_device) post_call(prev_device)
return out return out
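A conceptual reference for this dequantization step, assuming `row_stats` and `col_stats` hold the absmax values used for the int8 scaling and that bias is optional; the kernel fuses the same arithmetic:

import torch

def mm_dequant_ref(C_int32, row_stats, col_stats, bias=None):
    # undo the row-wise and column-wise 127-scaling applied before the int8 matmul
    out = C_int32.float() * (row_stats.view(-1, 1) / 127.0) * (col_stats.view(1, -1) / 127.0)
    if bias is not None:
        out = out + bias.float()
    return out.half()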
def get_colrow_absmax( def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
assert A.dtype == torch.float16 assert A.dtype == torch.float16
device = A.device device = A.device
...@@ -2037,18 +2398,12 @@ def get_colrow_absmax( ...@@ -2037,18 +2398,12 @@ def get_colrow_absmax(
col_tiles = (cols + 255) // 256 col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16 tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None: if row_stats is None:
row_stats = torch.empty( row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
(rows,), dtype=torch.float32, device=device
).fill_(-50000.0)
if col_stats is None: if col_stats is None:
col_stats = torch.empty( col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
(cols,), dtype=torch.float32, device=device
).fill_(-50000.0)
if nnz_block_ptr is None and threshold > 0.0: if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros( nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)
((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
)
ptrA = get_ptr(A) ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats) ptrRowStats = get_ptr(row_stats)
...@@ -2122,14 +2477,10 @@ class CSCSparseTensor: ...@@ -2122,14 +2477,10 @@ class CSCSparseTensor:
def coo2csr(cooA): def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True) values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1) values.add_(1)
rowptr = torch.zeros( rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
(cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
)
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0) rowptr.cumsum_(0)
return CSRSparseTensor( return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
)
def coo2csc(cooA): def coo2csc(cooA):
...@@ -2138,14 +2489,10 @@ def coo2csc(cooA): ...@@ -2138,14 +2489,10 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx] values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True) colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1) colvalues.add_(1)
colptr = torch.zeros( colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
(cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
)
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0) colptr.cumsum_(0)
return CSCSparseTensor( return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
)
def coo_zeros(rows, cols, nnz, device, dtype=torch.half): def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
...@@ -2155,9 +2502,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): ...@@ -2155,9 +2502,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
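The rowptr construction in `coo2csr` above can also be written with `bincount` + `cumsum`; a small self-contained check of the layout it produces (values are illustrative):

import torch

rowidx = torch.tensor([0, 0, 2, 2, 2, 3])     # COO row indices of a 4-row matrix
rowptr = torch.zeros(5, dtype=torch.int32)
rowptr[1:] = torch.bincount(rowidx, minlength=4).cumsum(0).int()
# rowptr == tensor([0, 2, 2, 5, 6]): row i's nonzeros live in values[rowptr[i]:rowptr[i + 1]]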
def double_quant( def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
device = A.device device = A.device
assert A.dtype == torch.half assert A.dtype == torch.half
assert device.type == "cuda" assert device.type == "cuda"
...@@ -2170,9 +2515,7 @@ def double_quant( ...@@ -2170,9 +2515,7 @@ def double_quant(
rows = A.shape[0] rows = A.shape[0]
if row_stats is None or col_stats is None: if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
A, threshold=threshold
)
if out_col is None: if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
...@@ -2190,9 +2533,7 @@ def double_quant( ...@@ -2190,9 +2533,7 @@ def double_quant(
if threshold > 0.0: if threshold > 0.0:
nnz = nnz_row_ptr[-1].item() nnz = nnz_row_ptr[-1].item()
if nnz > 0: if nnz > 0:
coo_tensor = coo_zeros( coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
)
ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx) ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values) ptrVal = get_ptr(coo_tensor.values)
...@@ -2251,12 +2592,16 @@ def double_quant( ...@@ -2251,12 +2592,16 @@ def double_quant(
return out_row, out_col, row_stats, col_stats, coo_tensor return out_row, out_col, row_stats, col_stats, coo_tensor
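A simplified CPU sketch of the row-wise half of this scheme, including the outlier threshold; the CUDA path additionally produces column-wise stats and a COO tensor holding the outliers:

import torch

def rowwise_int8_ref(A, threshold=6.0):
    outliers = A.abs() >= threshold                          # these columns stay in fp16 elsewhere
    kept = A.masked_fill(outliers, 0.0).float()
    row_absmax = kept.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    q = torch.round(kept / row_absmax * 127.0).to(torch.int8)
    return q, row_absmax.squeeze(1), outliers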
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device) prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order) if state is None:
else: from_order = state[1] state = (A.shape, from_order)
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) else:
else: new_state = (state[0], to_order) # (shape, order) from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else:
new_state = (state[0], to_order) # (shape, order)
shape = state[0] shape = state[0]
if len(shape) == 2: if len(shape) == 2:
...@@ -2267,7 +2612,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ...@@ -2267,7 +2612,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
dim2 = ct.c_int32(shape[2]) dim2 = ct.c_int32(shape[2])
is_on_gpu([A, out]) is_on_gpu([A, out])
if to_order == 'col32': if to_order == "col32":
if transpose: if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else: else:
...@@ -2288,7 +2633,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ...@@ -2288,7 +2633,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
elif from_order == "col_ampere": elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else: else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
post_call(prev_device) post_call(prev_device)
...@@ -2297,9 +2642,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ...@@ -2297,9 +2642,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
def spmm_coo(cooA, B, out=None): def spmm_coo(cooA, B, out=None):
if out is None: if out is None:
out = torch.empty( out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
(cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
)
nnz = cooA.nnz nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz assert cooA.colidx.numel() == nnz
...@@ -2326,16 +2669,28 @@ def spmm_coo(cooA, B, out=None): ...@@ -2326,16 +2669,28 @@ def spmm_coo(cooA, B, out=None):
cldc = ct.c_int32(ldc) cldc = ct.c_int32(ldc)
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) lib.cspmm_coo(
ptr,
ptrRowidx,
ptrColidx,
ptrValues,
cnnz,
crowsA,
ccolsA,
ccolsB,
cldb,
ptrB,
cldc,
ptrC,
ct.c_bool(transposed_B),
)
return out return out
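For reference, the same product can be formed with PyTorch's built-in sparse COO support (float32 on CPU; the shapes below are illustrative, not taken from this file):

import torch

indices = torch.tensor([[0, 1, 1], [2, 0, 3]])   # row indices, then column indices
values = torch.randn(3)
A = torch.sparse_coo_tensor(indices, values, (2, 4))
B = torch.randn(4, 8)
out_ref = torch.sparse.mm(A, B)                  # dense (2, 8) result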
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
if out is None: if out is None:
out = torch.zeros( out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
nnz = cooA.nnz nnz = cooA.nnz
prev_device = pre_call(B.device) prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz assert cooA.rowidx.numel() == nnz
...@@ -2353,9 +2708,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ...@@ -2353,9 +2708,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
max_count, max_idx = torch.sort(counts, descending=True) max_count, max_idx = torch.sort(counts, descending=True)
max_idx = max_idx.int() max_idx = max_idx.int()
max_count = max_count.int() max_count = max_count.int()
assert ( assert max_count[0] <= 32, f"Current max count per row is 32 but found {max_count[0]}."
max_count[0] <= 32
), f"Current max count per row is 32 but found {max_count[0]}."
assert B.dtype in [torch.float16, torch.int8] assert B.dtype in [torch.float16, torch.int8]
ptrOffset = get_ptr(offset) ptrOffset = get_ptr(offset)
ptrMaxCount = get_ptr(max_count) ptrMaxCount = get_ptr(max_count)
...@@ -2443,9 +2796,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): ...@@ -2443,9 +2796,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
elif quant_type in ["vector-zeropoint", "row-zeropoint"]: elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype dtype = x.dtype
x = x.float() x = x.float()
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
x, dim=dim, keepdim=True
)
dyna[dyna == 0] = 1 dyna[dyna == 0] = 1
qx = 255.0 / dyna qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True) minx = torch.amin(x, dim=dim, keepdim=True)
...@@ -2553,9 +2904,7 @@ def extract_outliers(A, SA, idx): ...@@ -2553,9 +2904,7 @@ def extract_outliers(A, SA, idx):
assert formatA in ["col_turing", "col_ampere"] assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda" assert A.device.type == "cuda"
out = torch.zeros( out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
(shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
)
idx_size = ct.c_int32(idx.numel()) idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0]) rows = ct.c_int32(shapeA[0])
...@@ -2565,7 +2914,7 @@ def extract_outliers(A, SA, idx): ...@@ -2565,7 +2914,7 @@ def extract_outliers(A, SA, idx):
ptrOut = get_ptr(out) ptrOut = get_ptr(out)
prev_device = pre_call(A.device) prev_device = pre_call(A.device)
if formatA == 'col_turing': if formatA == "col_turing":
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere": elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
...@@ -2573,6 +2922,7 @@ def extract_outliers(A, SA, idx): ...@@ -2573,6 +2922,7 @@ def extract_outliers(A, SA, idx):
return out return out
def pipeline_test(A, batch_size): def pipeline_test(A, batch_size):
out = torch.zeros_like(A) out = torch.zeros_like(A)
lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
......
...@@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding): ...@@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding):
reset_parameters(): Reset embedding parameters using Xavier uniform initialization. reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer. forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
""" """
def __init__( def __init__(
self, self,
num_embeddings: int, num_embeddings: int,
...@@ -89,9 +90,7 @@ class StableEmbedding(torch.nn.Embedding): ...@@ -89,9 +90,7 @@ class StableEmbedding(torch.nn.Embedding):
dtype, dtype,
) )
self.norm = torch.nn.LayerNorm(embedding_dim, device=device) self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
GlobalOptimManager.get_instance().register_module_override( GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
self, "weight", {"optim_bits": 32}
)
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight) torch.nn.init.xavier_uniform_(self.weight)
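The StableEmbedding docstring above lists `reset_parameters()` and `forward()`, and the constructor registers a 32-bit optimizer-state override for the weight. A minimal usage sketch (tensor shapes are arbitrary):

```
import torch
import bitsandbytes as bnb

# StableEmbedding = torch.nn.Embedding + LayerNorm, with its optimizer
# state kept in 32 bits via the GlobalOptimManager override above.
emb = bnb.nn.StableEmbedding(num_embeddings=1000, embedding_dim=64)
token_ids = torch.randint(0, 1000, (8, 16))
out = emb(token_ids)  # shape (8, 16, 64)
```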
...@@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding): ...@@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding):
""" """
Embedding class to store and retrieve word embeddings from their indices. Embedding class to store and retrieve word embeddings from their indices.
""" """
def __init__( def __init__(
self, self,
num_embeddings: int, num_embeddings: int,
...@@ -170,11 +170,9 @@ class Embedding(torch.nn.Embedding): ...@@ -170,11 +170,9 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq, scale_grad_by_freq,
sparse, sparse,
_weight, _weight,
device=device device=device,
)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
) )
GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight) torch.nn.init.xavier_uniform_(self.weight)
...@@ -214,10 +212,10 @@ class Params4bit(torch.nn.Parameter): ...@@ -214,10 +212,10 @@ class Params4bit(torch.nn.Parameter):
quant_state: Optional[QuantState] = None, quant_state: Optional[QuantState] = None,
blocksize: int = 64, blocksize: int = 64,
compress_statistics: bool = True, compress_statistics: bool = True,
quant_type: str = 'fp4', quant_type: str = "fp4",
quant_storage: torch.dtype = torch.uint8, quant_storage: torch.dtype = torch.uint8,
module: Optional["Linear4bit"] = None, module: Optional["Linear4bit"] = None,
bnb_quantized: bool = False bnb_quantized: bool = False,
) -> "Params4bit": ) -> "Params4bit":
if data is None: if data is None:
data = torch.empty(0) data = torch.empty(0)
...@@ -250,7 +248,7 @@ class Params4bit(torch.nn.Parameter): ...@@ -250,7 +248,7 @@ class Params4bit(torch.nn.Parameter):
self.bnb_quantized = state["bnb_quantized"] self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"] self.module = state["module"]
def __deepcopy__(self,memo): def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self)) new_instance = type(self).__new__(type(self))
state = self.__getstate__() state = self.__getstate__()
new_instance.__setstate__(state) new_instance.__setstate__(state)
...@@ -265,7 +263,14 @@ class Params4bit(torch.nn.Parameter): ...@@ -265,7 +263,14 @@ class Params4bit(torch.nn.Parameter):
return new_instance return new_instance
@classmethod @classmethod
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit": def from_prequantized(
cls,
data: torch.Tensor,
quantized_stats: Dict[str, Any],
requires_grad: bool = False,
device="cuda",
**kwargs,
) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device)) self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
...@@ -292,33 +297,39 @@ class Params4bit(torch.nn.Parameter): ...@@ -292,33 +297,39 @@ class Params4bit(torch.nn.Parameter):
return self return self
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
@overload @overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: def to(
... self: T,
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...,
) -> T: ...
@overload @overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
...
@overload @overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
...
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (device is not None and device.type == "cuda" and not self.bnb_quantized): if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device) return self._quantize(device)
else: else:
if self.quant_state is not None: if self.quant_state is not None:
self.quant_state.to(device) self.quant_state.to(device)
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), new_param = Params4bit(
requires_grad=self.requires_grad, quant_state=self.quant_state, super().to(device=device, dtype=dtype, non_blocking=non_blocking),
blocksize=self.blocksize, compress_statistics=self.compress_statistics, requires_grad=self.requires_grad,
quant_type=self.quant_type) quant_state=self.quant_state,
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
)
return new_param return new_param
...@@ -355,7 +366,18 @@ class Linear4bit(nn.Linear): ...@@ -355,7 +366,18 @@ class Linear4bit(nn.Linear):
quantized_model = quantized_model.to(0) # Quantization happens here quantized_model = quantized_model.to(0) # Quantization happens here
``` ```
""" """
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_type="fp4",
quant_storage=torch.uint8,
device=None,
):
""" """
Initialize Linear4bit class. Initialize Linear4bit class.
...@@ -368,7 +390,14 @@ class Linear4bit(nn.Linear): ...@@ -368,7 +390,14 @@ class Linear4bit(nn.Linear):
Whether the linear class uses the bias term as well. Whether the linear class uses the bias term as well.
""" """
super().__init__(input_features, output_features, bias, device) super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) self.weight = Params4bit(
self.weight.data,
requires_grad=False,
compress_statistics=compress_statistics,
quant_type=quant_type,
quant_storage=quant_storage,
module=self,
)
# self.persistent_buffers = [] # TODO consider as way to save quant state # self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype self.compute_dtype = compute_dtype
self.compute_type_is_set = False self.compute_type_is_set = False
...@@ -385,11 +414,15 @@ class Linear4bit(nn.Linear): ...@@ -385,11 +414,15 @@ class Linear4bit(nn.Linear):
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this # warn the user about this
warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') warnings.warn(
warnings.filterwarnings('ignore', message='.*inference.') "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
)
warnings.filterwarnings("ignore", message=".*inference.")
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') warnings.warn(
warnings.filterwarnings('ignore', message='.*inference or training') "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
)
warnings.filterwarnings("ignore", message=".*inference or training")
def _save_to_state_dict(self, destination, prefix, keep_vars): def _save_to_state_dict(self, destination, prefix, keep_vars):
""" """
...@@ -407,8 +440,8 @@ class Linear4bit(nn.Linear): ...@@ -407,8 +440,8 @@ class Linear4bit(nn.Linear):
if self.bias is not None and self.bias.dtype != x.dtype: if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype) self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, 'quant_state', None) is None: if getattr(self.weight, "quant_state", None) is None:
if getattr(self, 'quant_state', None) is not None: if getattr(self, "quant_state", None) is not None:
# the quant state got lost when the parameter got converted. This happens for example for fsdp # the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here # since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1 assert self.weight.shape[1] == 1
...@@ -416,7 +449,9 @@ class Linear4bit(nn.Linear): ...@@ -416,7 +449,9 @@ class Linear4bit(nn.Linear):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
self.weight.quant_state = self.quant_state self.weight.quant_state = self.quant_state
else: else:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)
if not self.compute_type_is_set: if not self.compute_type_is_set:
self.set_compute_type(x) self.set_compute_type(x)
self.compute_type_is_set = True self.compute_type_is_set = True
...@@ -437,7 +472,17 @@ class LinearFP4(Linear4bit): ...@@ -437,7 +472,17 @@ class LinearFP4(Linear4bit):
""" """
Implements the FP4 data type. Implements the FP4 data type.
""" """
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_storage=torch.uint8,
device=None,
):
""" """
Args: Args:
input_features (`str`): input_features (`str`):
...@@ -447,11 +492,20 @@ class LinearFP4(Linear4bit): ...@@ -447,11 +492,20 @@ class LinearFP4(Linear4bit):
bias (`bool`, defaults to `True`): bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well. Whether the linear class uses the bias term as well.
""" """
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) super().__init__(
input_features,
output_features,
bias,
compute_dtype,
compress_statistics,
"fp4",
quant_storage,
device,
)
class LinearNF4(Linear4bit): class LinearNF4(Linear4bit):
''' Implements the NF4 data type. """Implements the NF4 data type.
Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
is normalized into the range [-1, 1]. is normalized into the range [-1, 1].
...@@ -460,8 +514,18 @@ class LinearNF4(Linear4bit): ...@@ -460,8 +514,18 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
''' """
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_storage=torch.uint8,
device=None,
):
""" """
Args: Args:
input_features (`str`): input_features (`str`):
...@@ -471,7 +535,16 @@ class LinearNF4(Linear4bit): ...@@ -471,7 +535,16 @@ class LinearNF4(Linear4bit):
bias (`bool`, defaults to `True`): bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well. Whether the linear class uses the bias term as well.
""" """
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) super().__init__(
input_features,
output_features,
bias,
compute_dtype,
compress_statistics,
"nf4",
quant_storage,
device,
)
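The LinearNF4 docstring above describes a code book whose bins carry equal probability mass under N(0, 1), rescaled into [-1, 1]. The sketch below illustrates that idea from normal quantiles; it is not the actual `create_normal_map` implementation, which treats the positive and negative halves asymmetrically and reserves an exact zero:

```
import torch
from scipy.stats import norm

def equal_area_codebook(num_bins: int = 16) -> torch.Tensor:
    # Slice N(0, 1) into num_bins regions of equal probability mass,
    # take each region's central quantile, then rescale into [-1, 1].
    edges = torch.linspace(0, 1, num_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    values = torch.from_numpy(norm.ppf(centers.numpy()))
    return values / values.abs().max()

print(equal_area_codebook())
```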
class Int8Params(torch.nn.Parameter): class Int8Params(torch.nn.Parameter):
...@@ -514,33 +587,22 @@ class Int8Params(torch.nn.Parameter): ...@@ -514,33 +587,22 @@ class Int8Params(torch.nn.Parameter):
device: Optional[Union[int, device]] = ..., device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ..., dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ..., non_blocking: bool = ...,
) -> T: ) -> T: ...
...
@overload @overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
...
@overload @overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
...
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
*args, **kwargs
)
if ( if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
device is not None
and device.type == "cuda"
and self.data.device.type == "cpu"
):
return self.cuda(device) return self.cuda(device)
else: else:
new_param = Int8Params( new_param = Int8Params(
super().to( super().to(device=device, dtype=dtype, non_blocking=non_blocking),
device=device, dtype=dtype, non_blocking=non_blocking
),
requires_grad=self.requires_grad, requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights, has_fp16_weights=self.has_fp16_weights,
) )
...@@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear): ...@@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear):
int8_model = int8_model.to(0) # Quantization happens here int8_model = int8_model.to(0) # Quantization happens here
``` ```
""" """
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None, device=None): def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None,
):
""" """
Initialize Linear8bitLt class. Initialize Linear8bitLt class.
...@@ -647,19 +719,36 @@ class Linear8bitLt(nn.Linear): ...@@ -647,19 +719,36 @@ class Linear8bitLt(nn.Linear):
destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB destination[format_name] = self.state.formatB
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(
missing_keys, unexpected_keys, error_msgs): self,
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, state_dict,
error_msgs) prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
unexpected_copy = list(unexpected_keys) unexpected_copy = list(unexpected_keys)
for key in unexpected_copy: for key in unexpected_copy:
input_name = key[len(prefix):] input_name = key[len(prefix) :]
if input_name == "SCB": if input_name == "SCB":
if self.weight.SCB is None: if self.weight.SCB is None:
# buffers not yet initialized, can't access them directly without quantizing first # buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " raise RuntimeError(
"not supported. Please call module.cuda() before module.load_state_dict()") "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()",
)
input_param = state_dict[key] input_param = state_dict[key]
self.weight.SCB.copy_(input_param) self.weight.SCB.copy_(input_param)
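The RuntimeError above encodes the required ordering: the module must be quantized (so the `SCB` buffer exists) before a quantized checkpoint can be loaded into it. A minimal sketch, with a hypothetical checkpoint path:

```
import torch
import bitsandbytes as bnb

linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False)
linear = linear.cuda()                # quantize first so weight.SCB is populated
state = torch.load("int8_linear.pt")  # hypothetical checkpoint file
linear.load_state_dict(state)
```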
...@@ -702,18 +791,18 @@ class OutlierAwareLinear(nn.Linear): ...@@ -702,18 +791,18 @@ class OutlierAwareLinear(nn.Linear):
self.is_quantized = False self.is_quantized = False
def forward_with_outliers(self, x, outlier_idx): def forward_with_outliers(self, x, outlier_idx):
raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function")
def quantize_weight(self, w, outlier_idx): def quantize_weight(self, w, outlier_idx):
raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function') raise NotImplementedError("Please override the `quantize_weight(self, w, outlier_idx)` function")
def forward(self, x): def forward(self, x):
if self.outlier_dim is None: if self.outlier_dim is None:
tracer = OutlierTracer.get_instance() tracer = OutlierTracer.get_instance()
if not tracer.is_initialized(): if not tracer.is_initialized():
print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
outlier_idx = tracer.get_outliers(self.weight) outlier_idx = tracer.get_outliers(self.weight)
#print(outlier_idx, tracer.get_hvalue(self.weight)) # print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx self.outlier_dim = outlier_idx
if not self.is_quantized: if not self.is_quantized:
...@@ -721,6 +810,7 @@ class OutlierAwareLinear(nn.Linear): ...@@ -721,6 +810,7 @@ class OutlierAwareLinear(nn.Linear):
self.weight.data.copy_(w) self.weight.data.copy_(w)
self.is_quantized = True self.is_quantized = True
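The two NotImplementedError messages above mark the override points of OutlierAwareLinear. A hedged subclass sketch follows; the absmax round-trip is a stand-in quantization scheme, not the library's, and the import assumes OutlierAwareLinear is exported from `bitsandbytes.nn`:

```
import torch.nn.functional as F
from bitsandbytes.nn import OutlierAwareLinear

class NaiveOutlierLinear(OutlierAwareLinear):
    def quantize_weight(self, w, outlier_idx):
        # Stand-in: absmax int8 round-trip; a real subclass would keep
        # the columns in outlier_idx at higher precision.
        scale = w.abs().max() / 127.0
        return (w / scale).round().clamp(-127, 127) * scale

    def forward_with_outliers(self, x, outlier_idx):
        return F.linear(x, self.weight, self.bias)
```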
class SwitchBackLinearBnb(nn.Linear): class SwitchBackLinearBnb(nn.Linear):
def __init__( def __init__(
self, self,
...@@ -731,11 +821,9 @@ class SwitchBackLinearBnb(nn.Linear): ...@@ -731,11 +821,9 @@ class SwitchBackLinearBnb(nn.Linear):
memory_efficient_backward=False, memory_efficient_backward=False,
threshold=0.0, threshold=0.0,
index=None, index=None,
device=None device=None,
): ):
super().__init__( super().__init__(input_features, output_features, bias, device)
input_features, output_features, bias, device
)
self.state = bnb.MatmulLtState() self.state = bnb.MatmulLtState()
self.index = index self.index = index
...@@ -745,9 +833,7 @@ class SwitchBackLinearBnb(nn.Linear): ...@@ -745,9 +833,7 @@ class SwitchBackLinearBnb(nn.Linear):
if threshold > 0.0 and not has_fp16_weights: if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True self.state.use_pool = True
self.weight = Int8Params( self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
def init_8bit_state(self): def init_8bit_state(self):
self.state.CB = self.weight.CB self.state.CB = self.weight.CB
......
...@@ -22,7 +22,6 @@ from bitsandbytes.triton.triton_utils import is_triton_available ...@@ -22,7 +22,6 @@ from bitsandbytes.triton.triton_utils import is_triton_available
class _switchback_global(torch.autograd.Function): class _switchback_global(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, X_3D, W, bias): def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D] # reshape input to [N * L, D]
...@@ -37,9 +36,7 @@ class _switchback_global(torch.autograd.Function): ...@@ -37,9 +36,7 @@ class _switchback_global(torch.autograd.Function):
# matmult, fused dequant and add bias # matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized # call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize( return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod @staticmethod
def backward(ctx, G_3D): def backward(ctx, G_3D):
...@@ -56,7 +53,8 @@ class _switchback_global(torch.autograd.Function): ...@@ -56,7 +53,8 @@ class _switchback_global(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G) G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W) W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1 *G_3D.size()[:-1],
-1,
) )
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad # backward pass uses standard weight grad
...@@ -66,8 +64,8 @@ class _switchback_global(torch.autograd.Function): ...@@ -66,8 +64,8 @@ class _switchback_global(torch.autograd.Function):
return grad_X, grad_W, grad_bias return grad_X, grad_W, grad_bias
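The forward comment above mixes row-wise quantized activations with a globally quantized weight and lets the kernel fuse the dequantization and bias add. Below is a rough pure-PyTorch reference of what that fused step computes, for illustration only (the scale conventions are assumed, not taken from the kernel source):

```
import torch

def int8_matmul_mixed_dequantize_ref(X_int8, W_int8_t, state_X, state_W, bias=None):
    # X_int8: (M, K) int8 with per-row absmax scales state_X of shape (M,).
    # W_int8_t: (K, N) int8 with a single global absmax scale state_W.
    acc = X_int8.int() @ W_int8_t.int()  # int32 accumulation
    out = acc.float() * state_X.view(-1, 1) * state_W / (127.0 * 127.0)
    return out if bias is None else out + bias
```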
class _switchback_vectorrize(torch.autograd.Function):
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, X_3D, W, bias): def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D] # reshape input to [N * L, D]
...@@ -81,9 +79,7 @@ class _switchback_vectorrize(torch.autograd.Function): ...@@ -81,9 +79,7 @@ class _switchback_vectorrize(torch.autograd.Function):
# matmult, fused dequant and add bias # matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W # call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize( return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod @staticmethod
def backward(ctx, G_3D): def backward(ctx, G_3D):
...@@ -99,7 +95,8 @@ class _switchback_vectorrize(torch.autograd.Function): ...@@ -99,7 +95,8 @@ class _switchback_vectorrize(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G) G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W) W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1 *G_3D.size()[:-1],
-1,
) )
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad # backward pass uses standard weight grad
...@@ -109,8 +106,8 @@ class _switchback_vectorrize(torch.autograd.Function): ...@@ -109,8 +106,8 @@ class _switchback_vectorrize(torch.autograd.Function):
return grad_X, grad_W, grad_bias return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, X_3D, W, bias): def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D] # reshape input to [N * L, D]
...@@ -127,9 +124,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function): ...@@ -127,9 +124,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
# matmult, fused dequant and add bias # matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized # call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize( return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1)
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
@staticmethod @staticmethod
def backward(ctx, G_3D): def backward(ctx, G_3D):
...@@ -151,12 +146,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function): ...@@ -151,12 +146,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G) G_int8, state_G = quantize_rowwise(G)
del G del G
W_int8 = W_int8.t().contiguous() W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1)
*G_3D_sz[:-1], -1
)
return grad_X, grad_W, grad_bias return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear): class SwitchBackLinear(nn.Linear):
def __init__( def __init__(
self, self,
...@@ -166,20 +160,20 @@ class SwitchBackLinear(nn.Linear): ...@@ -166,20 +160,20 @@ class SwitchBackLinear(nn.Linear):
device=None, device=None,
dtype=None, dtype=None,
vector_wise_quantization: bool = False, vector_wise_quantization: bool = False,
mem_efficient : bool = False, mem_efficient: bool = False,
): ):
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available(): if not is_triton_available():
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""")
# By default, we use the global quantization. # By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization: if self.vector_wise_quantization:
self._fn = _switchback_vectorrize self._fn = _switchback_vectorrize
if mem_efficient: if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.') print("mem efficient is not supported for vector-wise quantization.")
exit(1) exit(1)
else: else:
if mem_efficient: if mem_efficient:
...@@ -195,7 +189,7 @@ class SwitchBackLinear(nn.Linear): ...@@ -195,7 +189,7 @@ class SwitchBackLinear(nn.Linear):
# if hasattr(m, "prepare_for_eval"): # if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval() # m.prepare_for_eval()
# model.apply(cond_prepare) # model.apply(cond_prepare)
print('=> preparing for eval.') print("=> preparing for eval.")
if self.vector_wise_quantization: if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight) W_int8, state_W = quantize_rowwise(self.weight)
else: else:
...@@ -219,18 +213,22 @@ class SwitchBackLinear(nn.Linear): ...@@ -219,18 +213,22 @@ class SwitchBackLinear(nn.Linear):
X_int8, state_X = quantize_rowwise(X) X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization: if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize( return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias *x.size()[:-1],
).view(*x.size()[:-1], -1) -1,
)
else: else:
return int8_matmul_mixed_dequantize( return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias *x.size()[:-1],
).view(*x.size()[:-1], -1) -1,
)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
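The three partials above only pin the quantization flags of SwitchBackLinear. A usage sketch, assuming these names are exported from `bitsandbytes.nn` and that triton plus a CUDA device are available (per the ImportError above):

```
import torch
from bitsandbytes.nn import SwitchBackLinearGlobal

layer = SwitchBackLinearGlobal(1024, 4096).cuda().half()
x = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16)
y = layer(x)  # int8 matmuls for X and G; standard fp16 weight gradient
```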
# This is just the standard linear function. # This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function): class StandardLinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
...@@ -260,7 +258,7 @@ class StandardLinearFunction(torch.autograd.Function): ...@@ -260,7 +258,7 @@ class StandardLinearFunction(torch.autograd.Function):
return grad_input, grad_weight, grad_bias return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
class StandardLinear(nn.Linear):
def forward(self, x): def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias) return StandardLinearFunction.apply(x, self.weight, self.bias)
...@@ -50,9 +50,7 @@ class Adagrad(Optimizer1State): ...@@ -50,9 +50,7 @@ class Adagrad(Optimizer1State):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}") raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
...@@ -119,9 +117,7 @@ class Adagrad8bit(Optimizer1State): ...@@ -119,9 +117,7 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}") raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
...@@ -189,9 +185,7 @@ class Adagrad32bit(Optimizer1State): ...@@ -189,9 +185,7 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}") raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
......
...@@ -14,8 +14,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State ...@@ -14,8 +14,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State): class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Base Adam optimizer. Base Adam optimizer.
...@@ -45,11 +58,38 @@ class Adam(Optimizer2State): ...@@ -45,11 +58,38 @@ class Adam(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
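The positional arguments above are forwarded straight into Optimizer2State, so `optim_bits` selects the optimizer-state precision. A minimal training-step sketch (assumes a CUDA device):

```
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
# optim_bits=8 here behaves like bnb.optim.Adam8bit below.
opt = bnb.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), optim_bits=8)
loss = model(torch.randn(16, 4096, device="cuda")).sum()
loss.backward()
opt.step()
opt.zero_grad()
```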
class Adam8bit(Optimizer2State): class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit Adam optimizer. 8-bit Adam optimizer.
...@@ -79,11 +119,38 @@ class Adam8bit(Optimizer2State): ...@@ -79,11 +119,38 @@ class Adam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Adam32bit(Optimizer2State): class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
32-bit Adam optimizer. 32-bit Adam optimizer.
...@@ -113,11 +180,38 @@ class Adam32bit(Optimizer2State): ...@@ -113,11 +180,38 @@ class Adam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdam(Optimizer2State): class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Paged Adam optimizer. Paged Adam optimizer.
...@@ -147,11 +241,38 @@ class PagedAdam(Optimizer2State): ...@@ -147,11 +241,38 @@ class PagedAdam(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam8bit(Optimizer2State): class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit paged Adam optimizer. 8-bit paged Adam optimizer.
...@@ -181,11 +302,38 @@ class PagedAdam8bit(Optimizer2State): ...@@ -181,11 +302,38 @@ class PagedAdam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam32bit(Optimizer2State): class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Paged 32-bit Adam optimizer. Paged 32-bit Adam optimizer.
...@@ -215,7 +363,21 @@ class PagedAdam32bit(Optimizer2State): ...@@ -215,7 +363,21 @@ class PagedAdam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class AnalysisAdam(torch.optim.Optimizer): class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis. """Adam that performs 8-bit vs 32-bit error analysis.
...@@ -293,9 +455,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -293,9 +455,7 @@ class AnalysisAdam(torch.optim.Optimizer):
if grad.dtype in {torch.float16, torch.bfloat16}: if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float() grad = grad.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
amsgrad = group.get("amsgrad", False) amsgrad = group.get("amsgrad", False)
assert not amsgrad assert not amsgrad
...@@ -312,15 +472,9 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -312,15 +472,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32) state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
state["abserrors"] = torch.zeros( state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
(256, 256), device=p_data_fp32.device state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
) state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["counts"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
...@@ -328,25 +482,19 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -328,25 +482,19 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = state["exp_avg"].to(p_data_fp32) state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
if amsgrad: if amsgrad:
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
p_data_fp32
)
state["step"] += 1 state["step"] += 1
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"] bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"]
step_size = ( step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
group["lr"] * math.sqrt(bias_correction2) / bias_correction1
)
e = state["abserrors"] e = state["abserrors"]
rele = state["relerrors"] rele = state["relerrors"]
counts = state["counts"] counts = state["counts"]
if group["weight_decay"] != 0: if group["weight_decay"] != 0:
p_data_fp32.add_( p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad: if amsgrad:
...@@ -359,10 +507,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -359,10 +507,7 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"]) denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom update_fp32 = exp_avg / denom
if ( if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
p_data_fp32.numel() <= 8192
or p_data_fp32.numel() > 50000 * 1000
):
# embedding layer or too small # embedding layer or too small
p_data_fp32 += -step_size * update_fp32 p_data_fp32 += -step_size * update_fp32
else: else:
...@@ -401,9 +546,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -401,9 +546,7 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize # 3. dequantize
# Error will be calculated automatically! # Error will be calculated automatically!
else: else:
raise ValueError( raise ValueError(f"Invalid analysis value: {self.analysis}!")
f"Invalid analysis value: {self.analysis}!"
)
denom = state2.sqrt().add_(group["eps"]) denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom update_8bit = state1 / denom
...@@ -415,9 +558,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -415,9 +558,7 @@ class AnalysisAdam(torch.optim.Optimizer):
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
F.histogram_scatter_add_2d( F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
counts, C1.int(), C2.int(), torch.ones_like(abserr)
)
p_data_fp32 += -step_size * update_fp32 p_data_fp32 += -step_size * update_fp32
...@@ -425,18 +566,10 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -425,18 +566,10 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0: if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir): if not os.path.exists(self.savedir):
os.makedirs(self.savedir) os.makedirs(self.savedir)
shapestr = "_".join( shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
[str(dim) for dim in p_data_fp32.shape] pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl")
) pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl")
pathe = os.path.join( pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl")
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
pathrele = os.path.join(
self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
)
pathcounts = os.path.join(
self.savedir, f"{p_id}_{shapestr}_counts.pkl"
)
torch.save(e, pathe) torch.save(e, pathe)
torch.save(rele, pathrele) torch.save(rele, pathrele)
torch.save(counts, pathcounts) torch.save(counts, pathcounts)
......
...@@ -6,8 +6,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State ...@@ -6,8 +6,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State): class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Base AdamW optimizer. Base AdamW optimizer.
...@@ -37,11 +50,38 @@ class AdamW(Optimizer2State): ...@@ -37,11 +50,38 @@ class AdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW8bit(Optimizer2State): class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit AdamW optimizer. 8-bit AdamW optimizer.
...@@ -71,11 +111,38 @@ class AdamW8bit(Optimizer2State): ...@@ -71,11 +111,38 @@ class AdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW32bit(Optimizer2State): class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
32-bit AdamW optimizer. 32-bit AdamW optimizer.
...@@ -105,12 +172,37 @@ class AdamW32bit(Optimizer2State): ...@@ -105,12 +172,37 @@ class AdamW32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdamW(Optimizer2State): class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged AdamW optimizer. Paged AdamW optimizer.
...@@ -140,11 +232,37 @@ class PagedAdamW(Optimizer2State): ...@@ -140,11 +232,37 @@ class PagedAdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW8bit(Optimizer2State): class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 8-bit AdamW optimizer. Paged 8-bit AdamW optimizer.
...@@ -174,11 +292,37 @@ class PagedAdamW8bit(Optimizer2State): ...@@ -174,11 +292,37 @@ class PagedAdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW32bit(Optimizer2State): class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 32-bit AdamW optimizer. Paged 32-bit AdamW optimizer.
...@@ -208,4 +352,17 @@ class PagedAdamW32bit(Optimizer2State): ...@@ -208,4 +352,17 @@ class PagedAdamW32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
...@@ -51,9 +51,7 @@ class LARS(Optimizer1State): ...@@ -51,9 +51,7 @@ class LARS(Optimizer1State):
The maximum gradient norm. The maximum gradient norm.
""" """
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError("LARS without momentum is not supported!")
"LARS without momentum is not supported!"
)
super().__init__( super().__init__(
"lars", "lars",
params, params,
...@@ -110,9 +108,7 @@ class LARS8bit(Optimizer1State): ...@@ -110,9 +108,7 @@ class LARS8bit(Optimizer1State):
The maximum gradient norm. The maximum gradient norm.
""" """
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError("LARS without momentum is not supported!")
"LARS without momentum is not supported!"
)
super().__init__( super().__init__(
"lars", "lars",
params, params,
...@@ -169,9 +165,7 @@ class LARS32bit(Optimizer1State): ...@@ -169,9 +165,7 @@ class LARS32bit(Optimizer1State):
The maximum gradient norm. The maximum gradient norm.
""" """
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError("LARS without momentum is not supported!")
"LARS without momentum is not supported!"
)
super().__init__( super().__init__(
"lars", "lars",
params, params,
...@@ -204,9 +198,7 @@ class PytorchLARS(Optimizer): ...@@ -204,9 +198,7 @@ class PytorchLARS(Optimizer):
if momentum < 0.0: if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}") raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0: if weight_decay < 0.0:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict( defaults = dict(
lr=lr, lr=lr,
...@@ -217,9 +209,7 @@ class PytorchLARS(Optimizer): ...@@ -217,9 +209,7 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm, max_unorm=max_unorm,
) )
if nesterov and (momentum <= 0 or dampening != 0): if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError( raise ValueError("Nesterov momentum requires a momentum and zero dampening")
"Nesterov momentum requires a momentum and zero dampening"
)
super().__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state): def __setstate__(self, state):
......
...@@ -6,7 +6,19 @@ from bitsandbytes.optim.optimizer import Optimizer1State ...@@ -6,7 +6,19 @@ from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State): class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Base Lion optimizer. Base Lion optimizer.
...@@ -32,10 +44,35 @@ class Lion(Optimizer1State): ...@@ -32,10 +44,35 @@ class Lion(Optimizer1State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion8bit(Optimizer1State): class Lion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit Lion optimizer. 8-bit Lion optimizer.
...@@ -59,10 +96,35 @@ class Lion8bit(Optimizer1State): ...@@ -59,10 +96,35 @@ class Lion8bit(Optimizer1State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion32bit(Optimizer1State): class Lion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
32-bit Lion optimizer. 32-bit Lion optimizer.
...@@ -86,11 +148,35 @@ class Lion32bit(Optimizer1State): ...@@ -86,11 +148,35 @@ class Lion32bit(Optimizer1State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedLion(Optimizer1State): class PagedLion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged Lion optimizer. Paged Lion optimizer.
...@@ -114,10 +200,34 @@ class PagedLion(Optimizer1State): ...@@ -114,10 +200,34 @@ class PagedLion(Optimizer1State):
block_wise (`bool`, defaults to `True`): block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion8bit(Optimizer1State): class PagedLion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 8-bit Lion optimizer. Paged 8-bit Lion optimizer.
...@@ -141,10 +251,34 @@ class PagedLion8bit(Optimizer1State): ...@@ -141,10 +251,34 @@ class PagedLion8bit(Optimizer1State):
block_wise (`bool`, defaults to `True`): block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion32bit(Optimizer1State): class PagedLion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 32-bit Lion optimizer. Paged 32-bit Lion optimizer.
...@@ -168,4 +302,17 @@ class PagedLion32bit(Optimizer1State): ...@@ -168,4 +302,17 @@ class PagedLion32bit(Optimizer1State):
block_wise (`bool`, defaults to `True`): block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
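The paged variants above take the same hyperparameters but force is_paged=True. As a hedged sketch, switching to the paged 8-bit optimizer is a one-line change; its state is allocated in paged memory so it can spill to CPU RAM when GPU memory is tight (the model below is a placeholder):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.PagedLion8bit(model.parameters(), lr=1e-4, weight_decay=1e-2)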
...@@ -21,6 +21,7 @@ class GlobalOptimManager: ...@@ -21,6 +21,7 @@ class GlobalOptimManager:
""" """
A global optimizer manager for enabling custom optimizer configs. A global optimizer manager for enabling custom optimizer configs.
""" """
_instance = None _instance = None
def __init__(self): def __init__(self):
...@@ -48,13 +49,9 @@ class GlobalOptimManager: ...@@ -48,13 +49,9 @@ class GlobalOptimManager:
for group_index, group in enumerate(param_groups): for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]): for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config: if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[ self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
id(p)
]
def override_config( def override_config(self, parameters, key=None, value=None, key_value_dict=None):
self, parameters, key=None, value=None, key_value_dict=None
):
""" """
Override initial optimizer config with specific hyperparameters. Override initial optimizer config with specific hyperparameters.
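override_config is the public entry point of this manager. A hedged sketch of the per-parameter override pattern; the registration step, its CPU-first ordering, and the layer chosen for the override follow the library README and are assumptions here, not part of this commit:

import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10))
mng.register_parameters(model.parameters())   # assumed: register while parameters are still on CPU
model = model.cuda()

optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)
mng.override_config(model[0].weight, "optim_bits", 32)   # keep this weight's state in 32 bit
# several keys at once via the dict form:
# mng.override_config(model[0].weight, key_value_dict={"optim_bits": 32})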
...@@ -170,16 +167,12 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -170,16 +167,12 @@ class Optimizer8bit(torch.optim.Optimizer):
saved_groups = state_dict["param_groups"] saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups): if len(groups) != len(saved_groups):
raise ValueError( raise ValueError("loaded state dict has a different number of parameter groups")
"loaded state dict has a different number of "
"parameter groups"
)
param_lens = (len(g["params"]) for g in groups) param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups) saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError( raise ValueError(
"loaded state dict contains a parameter group " "loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
"that doesn't match the size of optimizer's group"
) )
# Update the state # Update the state
...@@ -228,9 +221,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -228,9 +221,7 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"] new_group["params"] = group["params"]
return new_group return new_group
param_groups = [ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
update_group(g, ng) for g, ng in zip(groups, saved_groups)
]
self.__setstate__({"state": state, "param_groups": param_groups}) self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self): def to_gpu(self):
...@@ -240,7 +231,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -240,7 +231,7 @@ class Optimizer8bit(torch.optim.Optimizer):
values = self.state[p] values = self.state[p]
for k, v in values.items(): for k, v in values.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
is_paged = getattr(v, 'is_paged', False) is_paged = getattr(v, "is_paged", False)
if not is_paged: if not is_paged:
self.state[p][k] = v.to(p.device) self.state[p][k] = v.to(p.device)
...@@ -248,9 +239,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -248,9 +239,7 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple: for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr) pmodule = getattr(module, attr)
assert pmodule is not None assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance( assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
pmodule, torch.Parameter
)
found = False found = False
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
if found: if found:
...@@ -262,9 +251,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -262,9 +251,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter # found the matching parameter
# init override # init override
self.mng.pid2config[id(p)] = config self.mng.pid2config[id(p)] = config
self.mng.index2config[ self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
(gindex, pindex)
] = self.mng.pid2config[id(p)]
found = True found = True
@torch.no_grad() @torch.no_grad()
...@@ -287,7 +274,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -287,7 +274,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True self.initialized = True
#if self.is_paged: self.page_mng.prefetch_all() # if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]): for pindex, p in enumerate(group["params"]):
if p.grad is None: if p.grad is None:
...@@ -304,7 +291,6 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -304,7 +291,6 @@ class Optimizer8bit(torch.optim.Optimizer):
# to sync to make sure all tensors are in the right state # to sync to make sure all tensors are in the right state
torch.cuda.synchronize() torch.cuda.synchronize()
return loss return loss
def get_config(self, gindex, pindex, group): def get_config(self, gindex, pindex, group):
...@@ -328,9 +314,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -328,9 +314,7 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError("init_state method needs to be overridden") raise NotImplementedError("init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex): def update_step(self, group, p, gindex, pindex):
raise NotImplementedError( raise NotImplementedError("The update_step method needs to be overridden")
"The update_step method needs to be overridden"
)
def get_state_buffer(self, p, dtype=torch.float32): def get_state_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 1e5: if not self.is_paged or p.numel() < 1e5:
...@@ -345,12 +329,12 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -345,12 +329,12 @@ class Optimizer8bit(torch.optim.Optimizer):
def prefetch_state(self, p): def prefetch_state(self, p):
if self.is_paged: if self.is_paged:
state = self.state[p] state = self.state[p]
s1 = state['state1'] s1 = state["state1"]
is_paged = getattr(s1, 'is_paged', False) is_paged = getattr(s1, "is_paged", False)
if is_paged: if is_paged:
F.prefetch_tensor(state['state1']) F.prefetch_tensor(state["state1"])
if 'state2' in state: if "state2" in state:
F.prefetch_tensor(state['state2']) F.prefetch_tensor(state["state2"])
class Optimizer2State(Optimizer8bit): class Optimizer2State(Optimizer8bit):
...@@ -369,7 +353,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -369,7 +353,7 @@ class Optimizer2State(Optimizer8bit):
block_wise=True, block_wise=True,
max_unorm=0.0, max_unorm=0.0,
skip_zeros=False, skip_zeros=False,
is_paged=False is_paged=False,
): ):
""" """
Base 2-state update optimizer class. Base 2-state update optimizer class.
...@@ -414,13 +398,9 @@ class Optimizer2State(Optimizer8bit): ...@@ -414,13 +398,9 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas] betas = [float(b) for b in betas]
for i in range(len(betas)): for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0: if not 0.0 <= betas[i] < 1.0:
raise ValueError( raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
f"Invalid beta parameter at index {i}: {betas[i]}"
)
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged) super().__init__(params, defaults, optim_bits, is_paged)
...@@ -449,9 +429,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -449,9 +429,7 @@ class Optimizer2State(Optimizer8bit):
elif config["optim_bits"] == 8: elif config["optim_bits"] == 8:
dtype = torch.uint8 dtype = torch.uint8
else: else:
raise NotImplementedError( raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config["min_8bit_size"]: if p.numel() < config["min_8bit_size"]:
dtype = torch.float32 dtype = torch.float32
...@@ -459,21 +437,15 @@ class Optimizer2State(Optimizer8bit): ...@@ -459,21 +437,15 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p] state = self.state[p]
state["step"] = 0 state["step"] = 0
if dtype == torch.float32 or ( if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state["step"] == 0: if state["step"] == 0:
if "dynamic" not in self.name2qmap: if "dynamic" not in self.name2qmap:
self.fill_qmap() self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
p.device self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
p.device
)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"] state["qmap1"] = self.name2qmap["dynamic"]
...@@ -486,25 +458,13 @@ class Optimizer2State(Optimizer8bit): ...@@ -486,25 +458,13 @@ class Optimizer2State(Optimizer8bit):
blocks = n // 2048 blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0 blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros( state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
(blocks,), dtype=torch.float32, device=p.device state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else: else:
state["max1"] = torch.zeros( state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
(1,), dtype=torch.float32, device=p.device state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
) state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros( state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device) state["gnorm_vec"] = torch.zeros((100,), device=p.device)
...@@ -524,7 +484,10 @@ class Optimizer2State(Optimizer8bit): ...@@ -524,7 +484,10 @@ class Optimizer2State(Optimizer8bit):
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"] grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
) )
else: else:
gnorm_scale = 1.0 gnorm_scale = 1.0
...@@ -568,9 +531,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -568,9 +531,7 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"], state["new_max2"],
config["weight_decay"], config["weight_decay"],
gnorm_scale=gnorm_scale, gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"] unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
if config["max_unorm"] > 0.0
else None,
max_unorm=config["max_unorm"], max_unorm=config["max_unorm"],
) )
...@@ -615,7 +576,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -615,7 +576,7 @@ class Optimizer1State(Optimizer8bit):
block_wise=True, block_wise=True,
max_unorm=0.0, max_unorm=0.0,
skip_zeros=False, skip_zeros=False,
is_paged=False is_paged=False,
): ):
""" """
Base 1-state update optimizer class. Base 1-state update optimizer class.
...@@ -656,13 +617,9 @@ class Optimizer1State(Optimizer8bit): ...@@ -656,13 +617,9 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)): for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0: if not 0.0 <= betas[i] < 1.0:
raise ValueError( raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
f"Invalid beta parameter at index {i}: {betas[i]}"
)
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged) super().__init__(params, defaults, optim_bits, is_paged)
...@@ -691,9 +648,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -691,9 +648,7 @@ class Optimizer1State(Optimizer8bit):
elif config["optim_bits"] == 8: elif config["optim_bits"] == 8:
dtype = torch.uint8 dtype = torch.uint8
else: else:
raise NotImplementedError( raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config["min_8bit_size"]: if p.numel() < config["min_8bit_size"]:
dtype = torch.float32 dtype = torch.float32
...@@ -701,17 +656,13 @@ class Optimizer1State(Optimizer8bit): ...@@ -701,17 +656,13 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p] state = self.state[p]
state["step"] = 0 state["step"] = 0
if dtype == torch.float32 or ( if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state["step"] == 0: if state["step"] == 0:
if "dynamic" not in self.name2qmap: if "dynamic" not in self.name2qmap:
self.fill_qmap() self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
p.device
)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"] state["qmap1"] = self.name2qmap["dynamic"]
...@@ -721,16 +672,10 @@ class Optimizer1State(Optimizer8bit): ...@@ -721,16 +672,10 @@ class Optimizer1State(Optimizer8bit):
blocks = n // 2048 blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0 blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros( state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
(blocks,), dtype=torch.float32, device=p.device
)
else: else:
state["max1"] = torch.zeros( state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
(1,), dtype=torch.float32, device=p.device state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device) state["gnorm_vec"] = torch.zeros((100,), device=p.device)
...@@ -750,7 +695,10 @@ class Optimizer1State(Optimizer8bit): ...@@ -750,7 +695,10 @@ class Optimizer1State(Optimizer8bit):
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"] grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
) )
else: else:
gnorm_scale = 1.0 gnorm_scale = 1.0
...@@ -766,7 +714,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -766,7 +714,7 @@ class Optimizer1State(Optimizer8bit):
step, step,
config["lr"], config["lr"],
None, None,
config['betas'][1], config["betas"][1],
config["weight_decay"], config["weight_decay"],
gnorm_scale, gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None, state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......
...@@ -51,9 +51,7 @@ class RMSprop(Optimizer1State): ...@@ -51,9 +51,7 @@ class RMSprop(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
"RMSprop with alpha==0.0 is not supported!"
)
if centered: if centered:
raise NotImplementedError("Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
...@@ -116,9 +114,7 @@ class RMSprop8bit(Optimizer1State): ...@@ -116,9 +114,7 @@ class RMSprop8bit(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
"RMSprop with alpha==0.0 is not supported!"
)
if centered: if centered:
raise NotImplementedError("Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
...@@ -182,9 +178,7 @@ class RMSprop32bit(Optimizer1State): ...@@ -182,9 +178,7 @@ class RMSprop32bit(Optimizer1State):
""" """
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
"RMSprop with alpha==0.0 is not supported!"
)
if centered: if centered:
raise NotImplementedError("Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
......
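These hunks only shorten the NotImplementedError messages. As a hedged usage sketch, the 8-bit variant accepts torch.optim.RMSprop-style hyperparameters, while alpha=0.0 and centered=True are rejected as shown above (the model below is a placeholder):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-3, alpha=0.99, eps=1e-8)
# bnb.optim.RMSprop8bit(model.parameters(), alpha=0.0)      # raises NotImplementedError
# bnb.optim.RMSprop8bit(model.parameters(), centered=True)  # raises NotImplementedError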
...@@ -195,9 +195,9 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -195,9 +195,9 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.B = B ctx.B = B
ctx.bias = bias ctx.bias = bias
if A.shape[-1] == B.shape[0]: if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
else: else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A # 1. Quantize A
# 2. Quantize B # 2. Quantize B
...@@ -216,9 +216,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -216,9 +216,7 @@ class SwitchBackBnb(torch.autograd.Function):
# 1. Quantize A # 1. Quantize A
if len(A.shape) == 3: if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous() A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
A.to(torch.float16), threshold=state.threshold
)
if state.threshold > 0.0 and coo_tensorA is not None: if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights: if state.has_fp16_weights:
...@@ -234,14 +232,14 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -234,14 +232,14 @@ class SwitchBackBnb(torch.autograd.Function):
# we also need to convert it to the turing/ampere format # we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else: else:
#print('A shape', A.shape) # print('A shape', A.shape)
if not state.has_fp16_weights and state.CxB is None: if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None subA = None
# 2. Quantize B # 2. Quantize B
if state.has_fp16_weights: if state.has_fp16_weights:
#print('B shape', B.shape) # print('B shape', B.shape)
has_grad = True if (getattr(B, "grad", None) is not None) else False has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: if is_transposed:
...@@ -272,12 +270,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -272,12 +270,7 @@ class SwitchBackBnb(torch.autograd.Function):
# else: # else:
# state.idx = outlier_idx # state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = ( state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.to(A.dtype)
)
CA[:, state.idx.long()] = 0 CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()] subA = A[:, state.idx.long()]
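The hunk above zeroes the outlier columns of the int8 tensors and keeps them separately in fp16 (subA). A hedged, pure-PyTorch sketch of that split; it is not the library's double_quant/extract_outliers kernels, and the threshold default is illustrative:

import torch

def split_int8_and_outliers(a: torch.Tensor, threshold: float = 6.0):
    outlier_cols = (a.abs().amax(dim=0) > threshold).nonzero().flatten()  # columns kept in fp16
    sub_a = a[:, outlier_cols]
    absmax = a.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)            # per-row scale
    ca = torch.round(a / absmax * 127).to(torch.int8)                     # row-wise int8 codes
    ca[:, outlier_cols] = 0                                               # zeroed, handled in fp16
    return ca, absmax.squeeze(1), outlier_cols, sub_a

a = torch.randn(4, 16, dtype=torch.float16)
ca, scale, idx, sub_a = split_int8_and_outliers(a)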
...@@ -320,14 +313,13 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -320,14 +313,13 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.tensor_states = (None, None) ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None) ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
return clone_func(output.view(output_shape)) return clone_func(output.view(output_shape))
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
if ctx.is_empty: if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors CAt, subA, A = ctx.tensors
...@@ -342,9 +334,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -342,9 +334,7 @@ class SwitchBackBnb(torch.autograd.Function):
# Cast grad_output to fp16 # Cast grad_output to fp16
if len(grad_output.shape) == 3: if len(grad_output.shape) == 3:
grad_output = grad_output.reshape( grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
-1, grad_output.shape[-1]
).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
...@@ -357,25 +347,24 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -357,25 +347,24 @@ class SwitchBackBnb(torch.autograd.Function):
if state.CBt is not None: if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32") C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None: if state.CxBt is None:
state.CxBt, state.SBt = F.transform( state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
state.CBt, to_order=formatB, transpose=True
)
# print('back B shape', state.CxBt.shape) # print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape) # print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None: elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else: else:
raise Exception('State must contain either CBt or CB matrix for backward') raise Exception("State must contain either CBt or CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None return grad_A, grad_B, None, grad_bias, None
def get_block_sizes(input_matrix, weight_matrix): def get_block_sizes(input_matrix, weight_matrix):
input_features = input_matrix.shape[-1] input_features = input_matrix.shape[-1]
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]
array = [4096, 2048, 1024, 512, 256, 128, 64, 0] array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
bsz, bsz2 = 1024, 1024 bsz, bsz2 = 1024, 1024
for i, k in enumerate(array): for i, k in enumerate(array):
...@@ -399,7 +388,8 @@ def matmul_fp8_global( ...@@ -399,7 +388,8 @@ def matmul_fp8_global(
bsz: int = -1, bsz: int = -1,
bsz2: int = -1, bsz2: int = -1,
): ):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
...@@ -412,7 +402,8 @@ def matmul_fp8_mixed( ...@@ -412,7 +402,8 @@ def matmul_fp8_mixed(
bsz: int = -1, bsz: int = -1,
bsz2: int = -1, bsz2: int = -1,
): ):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
...@@ -422,7 +413,7 @@ def switchback_bnb( ...@@ -422,7 +413,7 @@ def switchback_bnb(
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None, state: Optional[MatmulLtState] = None,
threshold=0.0, threshold=0.0,
bias=None bias=None,
): ):
state = state or MatmulLtState() state = state or MatmulLtState()
if threshold > 0.0: if threshold > 0.0:
......
...@@ -28,12 +28,20 @@ class LinearFP8Mixed(nn.Linear): ...@@ -28,12 +28,20 @@ class LinearFP8Mixed(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) out = bnb.research.matmul_fp8_mixed(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None: if self.bias is not None:
out += self.bias out += self.bias
return out return out
class LinearFP8Global(nn.Linear): class LinearFP8Global(nn.Linear):
def __init__(self, input_features, output_features, bias=True): def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
...@@ -54,7 +62,14 @@ class LinearFP8Global(nn.Linear): ...@@ -54,7 +62,14 @@ class LinearFP8Global(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) out = bnb.matmul_fp8_global(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None: if self.bias is not None:
out += self.bias out += self.bias
......
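Both layers above only have their matmul call reflowed across lines. A hedged usage sketch of dropping the mixed-precision FP8 layer in place of nn.Linear; the import path and shapes are assumptions, not taken from this commit:

import torch
from bitsandbytes.research.nn import LinearFP8Mixed   # import path assumed

layer = LinearFP8Mixed(1024, 4096, bias=True).cuda()
x = torch.randn(8, 1024, device="cuda")
y = layer(x)      # forward builds fw/bw FP8 code maps lazily, then calls matmul_fp8_mixed
print(y.shape)    # torch.Size([8, 4096])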
...@@ -5,9 +5,10 @@ import torch ...@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available(): if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
return None
else:
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -29,7 +30,7 @@ else: ...@@ -29,7 +30,7 @@ else:
triton.Config({}, num_warps=4), triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8), triton.Config({}, num_warps=8),
], ],
key=['n_elements'] key=["n_elements"],
) )
@triton.jit @triton.jit
def _dequantize_rowwise( def _dequantize_rowwise(
...@@ -51,7 +52,6 @@ else: ...@@ -51,7 +52,6 @@ else:
output = max_val * x * inv_127 output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
...@@ -60,5 +60,5 @@ else: ...@@ -60,5 +60,5 @@ else:
assert x.is_cuda and output.is_cuda assert x.is_cuda and output.is_cuda
n_elements = output.numel() n_elements = output.numel()
grid = lambda meta: (x.shape[0],) grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output return output
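For reference, a hedged pure-PyTorch equivalent of the kernel body above (output = max_val * x * inv_127); it mirrors the Triton path for checking purposes and is not part of this commit:

import torch

def dequantize_rowwise_ref(x_int8: torch.Tensor, state_x: torch.Tensor) -> torch.Tensor:
    # per-row absmax times int8 code over 127, returned in fp16 like the kernel output
    return (state_x.unsqueeze(1) * x_int8.float() * (1.0 / 127.0)).to(torch.float16)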