Commit 5a4263f4 authored by Ruff, committed by Aarni Koskela

Reformat with ruff-format

parent 02e30ca6
@@ -7,9 +7,7 @@ def get_platform_tag(architecture):
    system = platform.system()
    if system == "Linux":
-        tag = (
-            "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
-        )
+        tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
    elif system == "Darwin":
        tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
    elif system == "Windows":
...
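The hunk above only collapses a parenthesized ternary onto one line; the mapping itself is unchanged. As a hypothetical usage sketch of the function named in the hunk header (return values inferred from the visible Linux branch only):

    get_platform_tag("x86_64")   # -> "manylinux_2_24_x86_64" when platform.system() == "Linux"
    get_platform_tag("aarch64")  # -> "manylinux_2_24_aarch64" when platform.system() == "Linux"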
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

-cmap=plt.get_cmap('cool')
+cmap = plt.get_cmap("cool")

-if __name__ == '__main__':
-    fig = plt.figure(tight_layout=True, figsize=(12,3.5))
+if __name__ == "__main__":
+    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]

@@ -19,25 +17,28 @@ if __name__ == '__main__':
    ax = fig.add_subplot(gs[0, 0])

    # TODO: change this to what you want.
-    rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
+    rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # first plot the time occupied by different operations
    for k, marker, ls, color, name in [
-        ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
-        ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
-        ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
-        ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
-        ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
-        ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
-        ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
-        ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
-        ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
-        ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
-        ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
+        ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
+        (
+            "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
+            "o",
+            "-",
+            "C4",
+            "SwitchBack int8 (sum of parts)",
+        ),
+        ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
+        ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
+        ("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
+        ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
+        ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
+        ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
+        ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
+        ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
+        ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
    ]:
        xs = []
        ys = []

@@ -47,40 +48,46 @@ if __name__ == '__main__':
            df_ = df_[df_.dim_out == embed_dim * 4]
            xs.append(embed_dim)
            y_ = 0
-            for k_ in k.split('+'):
+            for k_ in k.split("+"):
                y_ += df_[k_].values[0]
            df_ = df[df.dim_in == embed_dim * 4]
            df_ = df_[df_.dim_out == embed_dim]
-            for k_ in k.split('+'):
+            for k_ in k.split("+"):
                y_ += df_[k_].values[0]
            ys.append(y_ * 0.5)

-        ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
+        ax.plot(
+            xs,
+            ys,
+            color=color,
+            label=name,
+            marker=marker,
+            markersize=5 if marker == "s" else 5,
+            linestyle=ls,
+            linewidth=2 if "+" in k else 1.0,
+        )

-    ax.set_xlabel('dim', fontsize=13)
-    ax.set_ylabel('time (ms)', fontsize=13)
+    ax.set_xlabel("dim", fontsize=13)
+    ax.set_ylabel("time (ms)", fontsize=13)
    ax.grid()
-    ax.set_xscale('log')
+    ax.set_xscale("log")
    if logscale_plot1:
-        ax.set_yscale('log')
+        ax.set_yscale("log")
-    ax.tick_params(axis='x', labelsize=11)
-    ax.tick_params(axis='y', labelsize=11)
+    ax.tick_params(axis="x", labelsize=11)
+    ax.tick_params(axis="y", labelsize=11)
    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)
-    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
+    leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
-    leg.get_texts()[0].set_fontweight('bold')
-    leg.get_texts()[1].set_fontweight('bold')
+    leg.get_texts()[0].set_fontweight("bold")
+    leg.get_texts()[1].set_fontweight("bold")
    plt.subplots_adjust(left=0.1)
-    ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
+    ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)

    ax = fig.add_subplot(gs[0, 1])

@@ -88,10 +95,15 @@ if __name__ == '__main__':
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        all_xs, all_ys = [], []
        for k, marker, ls, color, name in [
-            ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
-            ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
+            ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
+            (
+                "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
+                "o",
+                "-",
+                "C4",
+                "SwitchBack int8 (total time)",
+            ),
        ]:
            xs, ys = [], []
            df = rdf[rdf.batch_size == batch_size]
            for embed_dim in dims_to_consider:

@@ -99,11 +111,11 @@ if __name__ == '__main__':
                df_ = df_[df_.dim_out == embed_dim * 4]
                xs.append(embed_dim)
                y_ = 0
-                for k_ in k.split('+'):
+                for k_ in k.split("+"):
                    y_ += df_[k_].values[0]
                df_ = df[df.dim_in == embed_dim * 4]
                df_ = df_[df_.dim_out == embed_dim]
-                for k_ in k.split('+'):
+                for k_ in k.split("+"):
                    y_ += df_[k_].values[0]
                ys.append(y_ * 0.5)
            all_xs.append(xs)

@@ -111,25 +123,29 @@ if __name__ == '__main__':
        color = cmap(j * 0.25)
        real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
-        markers = ['^', 'v', 'P', 'o']
-        ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
+        markers = ["^", "v", "P", "o"]
+        ax.plot(
+            all_xs[0],
+            real_ys,
+            color=color,
+            label=f"batch * sequence length = {batch_size}",
+            marker=markers[j],
+            markersize=5 if marker == "s" else 5,
+        )

    ax.legend()
-    ax.set_xlabel('dim', fontsize=13)
-    ax.set_xscale('log')
+    ax.set_xlabel("dim", fontsize=13)
+    ax.set_xscale("log")
    ax.grid()
-    ax.set_ylabel(r'% speedup', fontsize=13)
-    ax.tick_params(axis='x', labelsize=11)
-    ax.tick_params(axis='y', labelsize=11)
+    ax.set_ylabel(r"% speedup", fontsize=13)
+    ax.tick_params(axis="x", labelsize=11)
+    ax.tick_params(axis="y", labelsize=11)
    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)
-    ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
+    ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)

-    plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
+    plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
@@ -20,8 +20,8 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise

# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

def get_time(k, fn, info_dict):
    for _ in range(repeat // 2):
        fn()

@@ -36,16 +36,15 @@ def get_time(k, fn, info_dict):
    print(f"time {k}: {ms:.3f} ms")
    info_dict[k] = ms

-if __name__ == '__main__':
+if __name__ == "__main__":
    torch.manual_seed(0)

    wm = 4
    for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
        # note "batch_size" is actually "batch_size * embed_dim", which is why it's large
-        for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
+        for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
            # switch switches dim_in and dim_out
            for switch in [False, True]:
                # hparams
                repeat = 64
                batch_size = batch_size

@@ -73,35 +72,86 @@ if __name__ == '__main__':
                state_w_rowwise = w.max(dim=1)[0]
                state_w_global = w.max()

-                info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
-                get_time('standard_fwd', lambda : x.matmul(w.t()), info)
-                get_time('standard_gw', lambda : g.t().matmul(x), info)
-                get_time('standard_gx', lambda : g.matmul(w), info)
-                get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
-                get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-                get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-                get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
-                get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
-                get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
-                get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
-                get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
-                get_time('w_quantize_global', lambda : quantize_global(w), info)
-                get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
-                time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
-                time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
-                time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
-                print('TOTAL STANDARD', time_standard)
-                print('TOTAL ROWWISE', time_rowwise)
-                print('TOTAL GLOBAL', time_global)
-                print('speedup', -100*(time_global - time_standard)/time_standard)
-                info['time_standard'] = time_standard
-                info['time_rowwise'] = time_rowwise
-                info['time_global'] = time_global
+                info = {
+                    "repeat": repeat,
+                    "batch_size": batch_size,
+                    "dim_out": dim_out,
+                    "dim_in": dim_in,
+                    "wm": wm,
+                    "switch": switch,
+                }
+
+                get_time("standard_fwd", lambda: x.matmul(w.t()), info)
+                get_time("standard_gw", lambda: g.t().matmul(x), info)
+                get_time("standard_gx", lambda: g.matmul(w), info)
+                get_time(
+                    "rowwise_fwd",
+                    lambda: int8_matmul_rowwise_dequantize(
+                        x_int8,
+                        w_int8.t(),
+                        state_x_rowwise,
+                        state_w_columnwise,
+                        None,
+                    ),
+                    info,
+                )
+                get_time(
+                    "rowwise_bwd",
+                    lambda: int8_matmul_rowwise_dequantize(
+                        g_int8,
+                        wt_int8.t(),
+                        state_x_rowwise,
+                        state_w_rowwise,
+                        None,
+                    ),
+                    info,
+                )
+                get_time(
+                    "global_fwd",
+                    lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
+                    info,
+                )
+                get_time(
+                    "global_bwd",
+                    lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
+                    info,
+                )
+                get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
+                get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
+                get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
+                get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
+                get_time("w_quantize_global", lambda: quantize_global(w), info)
+                get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)
+
+                time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
+                time_rowwise = (
+                    info["x_quantize_rowwise"]
+                    + info["g_quantize_rowwise"]
+                    + info["w_quantize_colwise_transpose"]
+                    + info["w_quantize_rowwise"]
+                    + info["standard_gw"]
+                    + info["rowwise_fwd"]
+                    + info["rowwise_bwd"]
+                )
+                time_global = (
+                    info["x_quantize_rowwise"]
+                    + info["g_quantize_rowwise"]
+                    + info["w_quantize_global"]
+                    + info["w_quantize_global_transpose"]
+                    + info["standard_gw"]
+                    + info["global_fwd"]
+                    + info["global_bwd"]
+                )
+
+                print("TOTAL STANDARD", time_standard)
+                print("TOTAL ROWWISE", time_rowwise)
+                print("TOTAL GLOBAL", time_global)
+
+                print("speedup", -100 * (time_global - time_standard) / time_standard)
+
+                info["time_standard"] = time_standard
+                info["time_rowwise"] = time_rowwise
+                info["time_global"] = time_global

                info_json = json.dumps(info)
...
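The totals printed by the benchmark above are plain sums over the per-op timings stored in `info`, and the reported speedup is the relative saving of the int8 path over the fp16 baseline. A self-contained sketch of that bookkeeping with made-up numbers (the millisecond values below are illustrative only, not measurements):

    # Hypothetical per-op times in ms, using the same keys as the benchmark above.
    info = {
        "standard_fwd": 1.00, "standard_gx": 1.00, "standard_gw": 1.00,
        "x_quantize_rowwise": 0.05, "g_quantize_rowwise": 0.05,
        "w_quantize_global": 0.02, "w_quantize_global_transpose": 0.02,
        "global_fwd": 0.55, "global_bwd": 0.55,
    }
    time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]  # 3.00
    time_global = (
        info["x_quantize_rowwise"] + info["g_quantize_rowwise"]
        + info["w_quantize_global"] + info["w_quantize_global_transpose"]
        + info["standard_gw"] + info["global_fwd"] + info["global_bwd"]
    )  # 2.24
    print(-100 * (time_global - time_standard) / time_standard)  # ~25.3, i.e. the int8 path is ~25% faster here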
@@ -14,16 +14,18 @@ import bitsandbytes.functional as F

def prod(iterable):
    return reduce(operator.mul, iterable, 1)


# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

"""
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.
"""


class GlobalOutlierPooler:
    _instance = None

@@ -83,6 +85,7 @@ def get_inverse_transform_indices(
            break  # if all indices fit in i bytes, stop early
    return permuted_tile_indices


def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
    """
    Undo a tiled permutation such as turing or ampere layout

@@ -159,20 +162,12 @@ class MatMul8bit(torch.autograd.Function):
                )
                if not A.is_contiguous():
                    A = A.contiguous()
-                qA, S2 = F.vectorwise_quant(
-                    A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
-                )
+                qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
                igrad_B = F.igemm(qA.t(), qgrad_output)
-                grad_B = F.vectorwise_mm_dequant(
-                    igrad_B, S2.t(), S1, grad_output.dtype, quant_type
-                )
+                grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
            else:
-                qgrad_output, S1 = F.vectorwise_quant(
-                    grad_output, dim=dims, quant_type=quant_type
-                )
-                qA, S2 = F.vectorwise_quant(
-                    A, dim=dims, quant_type=quant_type
-                )
+                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                grad_B = F.vectorwise_mm_dequant(
                    igrad_B,

@@ -201,9 +196,7 @@ class MatMul8bit(torch.autograd.Function):
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
-                qgrad_output, S1 = F.vectorwise_quant(
-                    grad_output, dim=dims, quant_type=quant_type
-                )
+                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(

@@ -227,7 +220,7 @@ def supports_igemmlt(device: torch.device) -> bool:
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    device_name = torch.cuda.get_device_name(device=device)
-    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660')  # https://en.wikipedia.org/wiki/GeForce_16_series
+    nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660")  # https://en.wikipedia.org/wiki/GeForce_16_series
    if any(model_name in device_name for model_name in nvidia16_models):
        return False  # these devices are technically cuda 7.5-capable, but they lack tensor cores
    return True
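`supports_igemmlt` gates the int8 tensor-core (igemmlt) path: compute capability below 7.5 is rejected outright, and the GeForce 16 series is excluded because it lacks tensor cores. A hypothetical check on the current device (illustrative only):

    import torch
    from bitsandbytes.autograd._functions import supports_igemmlt

    supports_igemmlt(torch.device("cuda:0"))
    # -> False on a GTX 1630/1650/1660 even though it reports compute capability 7.5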
@@ -246,6 +239,7 @@ def get_tile_inds(format, device):
    with torch.no_grad():
        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)


@dataclass
class MatmulLtState:
    _tile_indices: Optional[torch.Tensor] = None

@@ -510,7 +504,6 @@ class MatMul4Bit(torch.autograd.Function):
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. MatmulnN
        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

@@ -532,7 +525,7 @@ class MatMul4Bit(torch.autograd.Function):
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

-        req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
+        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
        A, B = ctx.tensors
        grad_A, grad_B, grad_bias = None, None, None

@@ -542,8 +535,9 @@ class MatMul4Bit(torch.autograd.Function):
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # not supported by PyTorch. TODO: create work-around
-        #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
+        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
+        if req_gradA:
+            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

        return grad_A, grad_B, None, grad_bias, None

@@ -554,7 +548,7 @@ def matmul(
    out: Optional[torch.Tensor] = None,
    state: Optional[MatmulLtState] = None,
    threshold=0.0,
-    bias=None
+    bias=None,
):
    state = state or MatmulLtState()
    if threshold > 0.0:

@@ -562,11 +556,19 @@ def matmul(
    return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
+def matmul_4bit(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    quant_state: F.QuantState,
+    out: Optional[torch.Tensor] = None,
+    bias=None,
+):
    assert quant_state is not None
    if A.numel() == A.shape[-1] and A.requires_grad == False:
        if A.shape[-1] % quant_state.blocksize != 0:
-            warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+            warn(
+                f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
+            )
            return MatMul4Bit.apply(A, B, out, bias, quant_state)
        else:
            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
...
@@ -56,7 +56,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
            "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
            "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
            "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
-            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n"
+            "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
        )
    return PACKAGE_DIR / library_name

@@ -100,7 +100,7 @@ def get_native_library() -> BNBNativeLibrary:
        logger.warning(
            "The installed version of bitsandbytes was compiled without GPU support. "
-            "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
+            "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
        )
    return BNBNativeLibrary(dll)

@@ -120,5 +120,5 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
-"""
+""",
    )

@@ -120,7 +120,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
        The CUDA version for the compile might depend on your conda install, if using conda.
        Inspect CUDA version via `conda list | grep cuda`.
-        """
+        """,
        )
    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple

@@ -129,7 +129,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
            """
        WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
        You will be only to use 8-bit optimizers and quantization routines!
-        """
+        """,
        )

    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")

@@ -170,7 +170,7 @@ def print_cuda_runtime_diagnostics() -> None:
            In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
-            """
+            """,
        )
        for pth in cudart_paths:
            print(f"* Found CUDA runtime at: {pth}")

@@ -25,7 +25,7 @@ def sanity_check():
        See the documentation for more details if needed.
        Trying a simple check anyway, but this will likely fail...
-        """
+        """,
    )

    from bitsandbytes.optim import Adam

@@ -71,7 +71,7 @@ def main():
        print(
            f"WARNING: {__package__} is currently running as CPU-only!\n"
            "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
-            f"If you think that this is so erroneously,\nplease report an issue!"
+            f"If you think that this is so erroneously,\nplease report an issue!",
        )
    except Exception:
        traceback.print_exc()

@@ -80,6 +80,6 @@ def main():
        Above we output some debug information.
        Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
        WARNING: Please be sure to sanitize sensitive info from the output before posting it.
-        """
+        """,
    )
    sys.exit(1)
@@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding):
        reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
        forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
    """

    def __init__(
        self,
        num_embeddings: int,

@@ -89,9 +90,7 @@ class StableEmbedding(torch.nn.Embedding):
            dtype,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
-        GlobalOptimManager.get_instance().register_module_override(
-            self, "weight", {"optim_bits": 32}
-        )
+        GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)

@@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding):
    """
    Embedding class to store and retrieve word embeddings from their indices.
    """

    def __init__(
        self,
        num_embeddings: int,

@@ -170,11 +170,9 @@ class Embedding(torch.nn.Embedding):
            scale_grad_by_freq,
            sparse,
            _weight,
-            device=device
+            device=device,
        )
-        GlobalOptimManager.get_instance().register_module_override(
-            self, "weight", {"optim_bits": 32}
-        )
+        GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)

@@ -214,10 +212,10 @@ class Params4bit(torch.nn.Parameter):
        quant_state: Optional[QuantState] = None,
        blocksize: int = 64,
        compress_statistics: bool = True,
-        quant_type: str = 'fp4',
+        quant_type: str = "fp4",
        quant_storage: torch.dtype = torch.uint8,
        module: Optional["Linear4bit"] = None,
-        bnb_quantized: bool = False
+        bnb_quantized: bool = False,
    ) -> "Params4bit":
        if data is None:
            data = torch.empty(0)

@@ -250,7 +248,7 @@ class Params4bit(torch.nn.Parameter):
        self.bnb_quantized = state["bnb_quantized"]
        self.module = state["module"]

-    def __deepcopy__(self,memo):
+    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)

@@ -265,7 +263,14 @@ class Params4bit(torch.nn.Parameter):
        return new_instance

    @classmethod
-    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
+    def from_prequantized(
+        cls,
+        data: torch.Tensor,
+        quantized_stats: Dict[str, Any],
+        requires_grad: bool = False,
+        device="cuda",
+        **kwargs,
+    ) -> "Params4bit":
        self = torch.Tensor._make_subclass(cls, data.to(device))
        self.requires_grad = requires_grad
        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)

@@ -292,33 +297,39 @@ class Params4bit(torch.nn.Parameter):
        return self

    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
-        return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
+        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

    @overload
-    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
-        ...
+    def to(
+        self: T,
+        device: Optional[Union[int, device]] = ...,
+        dtype: Optional[Union[dtype, str]] = ...,
+        non_blocking: bool = ...,
+    ) -> T: ...

    @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...

    @overload
-    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if (device is not None and device.type == "cuda" and not self.bnb_quantized):
+        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        else:
            if self.quant_state is not None:
                self.quant_state.to(device)

-            new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
-                                   requires_grad=self.requires_grad, quant_state=self.quant_state,
-                                   blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-                                   quant_type=self.quant_type)
+            new_param = Params4bit(
+                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                requires_grad=self.requires_grad,
+                quant_state=self.quant_state,
+                blocksize=self.blocksize,
+                compress_statistics=self.compress_statistics,
+                quant_type=self.quant_type,
+            )

            return new_param
@@ -355,7 +366,18 @@ class Linear4bit(nn.Linear):
        quantized_model = quantized_model.to(0) # Quantization happens here
        ```
    """

-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+        device=None,
+    ):
        """
        Initialize Linear4bit class.

@@ -368,7 +390,14 @@ class Linear4bit(nn.Linear):
                Whether the linear class uses the bias term as well.
        """
        super().__init__(input_features, output_features, bias, device)
-        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
+        self.weight = Params4bit(
+            self.weight.data,
+            requires_grad=False,
+            compress_statistics=compress_statistics,
+            quant_type=quant_type,
+            quant_storage=quant_storage,
+            module=self,
+        )
        # self.persistent_buffers = []  # TODO consider as way to save quant state
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False

@@ -385,11 +414,15 @@ class Linear4bit(nn.Linear):
        if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
            # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
            # warn the user about this
-            warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
-            warnings.filterwarnings('ignore', message='.*inference.')
+            warnings.warn(
+                "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
+            )
+            warnings.filterwarnings("ignore", message=".*inference.")
        if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
-            warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
-            warnings.filterwarnings('ignore', message='.*inference or training')
+            warnings.warn(
+                "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
+            )
+            warnings.filterwarnings("ignore", message=".*inference or training")

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """

@@ -407,8 +440,8 @@ class Linear4bit(nn.Linear):
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

-        if getattr(self.weight, 'quant_state', None) is None:
-            if getattr(self, 'quant_state', None) is not None:
+        if getattr(self.weight, "quant_state", None) is None:
+            if getattr(self, "quant_state", None) is not None:
                # the quant state got lost when the parameter got converted. This happens for example for fsdp
                # since we registered the module, we can recover the state here
                assert self.weight.shape[1] == 1

@@ -416,7 +449,9 @@ class Linear4bit(nn.Linear):
                self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
                self.weight.quant_state = self.quant_state
            else:
-                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+                print(
+                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
+                )
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

@@ -437,7 +472,17 @@ class LinearFP4(Linear4bit):
    """
    Implements the FP4 data type.
    """

-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
        """
        Args:
            input_features (`str`):

@@ -447,11 +492,20 @@ class LinearFP4(Linear4bit):
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
        """
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
+        super().__init__(
+            input_features,
+            output_features,
+            bias,
+            compute_dtype,
+            compress_statistics,
+            "fp4",
+            quant_storage,
+            device,
+        )


class LinearNF4(Linear4bit):
-    ''' Implements the NF4 data type.
+    """Implements the NF4 data type.

    Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
    is normalized into the range [-1, 1].

@@ -460,8 +514,18 @@ class LinearNF4(Linear4bit):
    Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
    the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
-    '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+    """
+
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
        """
        Args:
            input_features (`str`):

@@ -471,7 +535,16 @@ class LinearNF4(Linear4bit):
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
        """
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
+        super().__init__(
+            input_features,
+            output_features,
+            bias,
+            compute_dtype,
+            compress_statistics,
+            "nf4",
+            quant_storage,
+            device,
+        )
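The LinearNF4 docstring above describes the NF4 code values as bins of equal area under a standard normal N(0, 1), normalized into [-1, 1]. A rough sketch of that construction, for reading purposes only (an assumed equal-probability-quantile version, not the library's `create_normal_map` implementation):

    import torch

    def equal_area_normal_codes(n_bins: int = 16) -> torch.Tensor:
        # Midpoint quantile of each of n_bins equal-probability slices of N(0, 1),
        # rescaled so the outermost codes land on -1 and 1.
        probs = (torch.arange(n_bins, dtype=torch.float64) + 0.5) / n_bins
        codes = torch.special.ndtri(probs)  # inverse CDF of the standard normal
        return codes / codes.abs().max()

    print(equal_area_normal_codes())  # 16 values in [-1, 1], one bin per 4-bit code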
class Int8Params(torch.nn.Parameter):

@@ -514,33 +587,22 @@ class Int8Params(torch.nn.Parameter):
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
-    ) -> T:
-        ...
+    ) -> T: ...

    @overload
-    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...

    @overload
-    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
-        ...
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...

    def to(self, *args, **kwargs):
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
-            *args, **kwargs
-        )
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if (
-            device is not None
-            and device.type == "cuda"
-            and self.data.device.type == "cpu"
-        ):
+        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
            return self.cuda(device)
        else:
            new_param = Int8Params(
-                super().to(
-                    device=device, dtype=dtype, non_blocking=non_blocking
-                ),
+                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )

@@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear):
        int8_model = int8_model.to(0) # Quantization happens here
        ```
    """

-    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+        device=None,
+    ):
        """
        Initialize Linear8bitLt class.

@@ -647,19 +719,36 @@ class Linear8bitLt(nn.Linear):
            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
            destination[format_name] = self.state.formatB

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )

        unexpected_copy = list(unexpected_keys)
        for key in unexpected_copy:
-            input_name = key[len(prefix):]
+            input_name = key[len(prefix) :]
            if input_name == "SCB":
                if self.weight.SCB is None:
                    # buffers not yet initialized, can't access them directly without quantizing first
-                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
-                                       "not supported. Please call module.cuda() before module.load_state_dict()")
+                    raise RuntimeError(
+                        "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
+                        "not supported. Please call module.cuda() before module.load_state_dict()",
+                    )

                input_param = state_dict[key]
                self.weight.SCB.copy_(input_param)

@@ -702,18 +791,18 @@ class OutlierAwareLinear(nn.Linear):
        self.is_quantized = False

    def forward_with_outliers(self, x, outlier_idx):
-        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
+        raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function")

    def quantize_weight(self, w, outlier_idx):
-        raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')
+        raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function")

    def forward(self, x):
        if self.outlier_dim is None:
            tracer = OutlierTracer.get_instance()
            if not tracer.is_initialized():
-                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
+                print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
            outlier_idx = tracer.get_outliers(self.weight)
-            #print(outlier_idx, tracer.get_hvalue(self.weight))
+            # print(outlier_idx, tracer.get_hvalue(self.weight))
            self.outlier_dim = outlier_idx

        if not self.is_quantized:

@@ -721,6 +810,7 @@ class OutlierAwareLinear(nn.Linear):
            self.weight.data.copy_(w)
            self.is_quantized = True


class SwitchBackLinearBnb(nn.Linear):
    def __init__(
        self,

@@ -731,11 +821,9 @@ class SwitchBackLinearBnb(nn.Linear):
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
-        device=None
+        device=None,
    ):
-        super().__init__(
-            input_features, output_features, bias, device
-        )
+        super().__init__(input_features, output_features, bias, device)
        self.state = bnb.MatmulLtState()
        self.index = index

@@ -745,9 +833,7 @@ class SwitchBackLinearBnb(nn.Linear):
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

-        self.weight = Int8Params(
-            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
-        )
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
...
@@ -22,7 +22,6 @@ from bitsandbytes.triton.triton_utils import is_triton_available

class _switchback_global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]

@@ -37,9 +36,7 @@ class _switchback_global(torch.autograd.Function):
        # matmult, fused dequant and add bias
        # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequantize(
-            X_int8, W_int8.t(), state_X, state_W, bias
-        ).view(*X_3D.size()[:-1], -1)
+        return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):

@@ -56,7 +53,8 @@ class _switchback_global(torch.autograd.Function):
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
-                *G_3D.size()[:-1], -1
+                *G_3D.size()[:-1],
+                -1,
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad

@@ -66,8 +64,8 @@ class _switchback_global(torch.autograd.Function):
        return grad_X, grad_W, grad_bias
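The forward pass above quantizes the activations row-wise to int8, quantizes the weight with a single global scale, and relies on one fused Triton kernel to do the int8 matmul, the dequantization, and the bias add. A pure-PyTorch reference sketch of the row-wise variant of that pattern (an assumed emulation for readability, not the fused `int8_matmul_rowwise_dequantize` kernel and not bit-identical to it):

    import torch

    def rowwise_quantize(t: torch.Tensor):
        # one scale per row, mapping the row's max magnitude to 127
        scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        return torch.round(t / scale).to(torch.int8), scale

    def int8_matmul_rowwise_dequantize_ref(x_int8, w_int8, scale_x, scale_w, bias=None):
        # x_int8: (N, D), w_int8: (O, D) quantized per output row
        acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()  # int32 accumulation
        out = acc.float() * scale_x * scale_w.t()                  # undo both scales
        return out + bias if bias is not None else out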
class _switchback_vectorrize(torch.autograd.Function):
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, X_3D, W, bias): def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D] # reshape input to [N * L, D]
...@@ -81,9 +79,7 @@ class _switchback_vectorrize(torch.autograd.Function): ...@@ -81,9 +79,7 @@ class _switchback_vectorrize(torch.autograd.Function):
# matmult, fused dequant and add bias # matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W # call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize( return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod @staticmethod
def backward(ctx, G_3D): def backward(ctx, G_3D):
...@@ -99,7 +95,8 @@ class _switchback_vectorrize(torch.autograd.Function): ...@@ -99,7 +95,8 @@ class _switchback_vectorrize(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G) G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W) W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1 *G_3D.size()[:-1],
-1,
) )
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad # backward pass uses standard weight grad
...@@ -109,8 +106,8 @@ class _switchback_vectorrize(torch.autograd.Function): ...@@ -109,8 +106,8 @@ class _switchback_vectorrize(torch.autograd.Function):
return grad_X, grad_W, grad_bias return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, X_3D, W, bias): def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D] # reshape input to [N * L, D]
...@@ -127,9 +124,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function): ...@@ -127,9 +124,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
# matmult, fused dequant and add bias # matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized # call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize( return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1)
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
@staticmethod @staticmethod
def backward(ctx, G_3D): def backward(ctx, G_3D):
...@@ -151,12 +146,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function): ...@@ -151,12 +146,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G) G_int8, state_G = quantize_rowwise(G)
del G del G
W_int8 = W_int8.t().contiguous() W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1)
*G_3D_sz[:-1], -1
)
return grad_X, grad_W, grad_bias return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear): class SwitchBackLinear(nn.Linear):
def __init__( def __init__(
self, self,
...@@ -166,20 +160,20 @@ class SwitchBackLinear(nn.Linear): ...@@ -166,20 +160,20 @@ class SwitchBackLinear(nn.Linear):
device=None, device=None,
dtype=None, dtype=None,
vector_wise_quantization: bool = False, vector_wise_quantization: bool = False,
mem_efficient : bool = False, mem_efficient: bool = False,
): ):
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available(): if not is_triton_available():
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""")
# By default, we use the global quantization. # By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization: if self.vector_wise_quantization:
self._fn = _switchback_vectorrize self._fn = _switchback_vectorrize
if mem_efficient: if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.') print("mem efficient is not supported for vector-wise quantization.")
exit(1) exit(1)
else: else:
if mem_efficient: if mem_efficient:
...@@ -195,7 +189,7 @@ class SwitchBackLinear(nn.Linear): ...@@ -195,7 +189,7 @@ class SwitchBackLinear(nn.Linear):
# if hasattr(m, "prepare_for_eval"): # if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval() # m.prepare_for_eval()
# model.apply(cond_prepare) # model.apply(cond_prepare)
print('=> preparing for eval.') print("=> preparing for eval.")
if self.vector_wise_quantization: if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight) W_int8, state_W = quantize_rowwise(self.weight)
else: else:
...@@ -219,18 +213,22 @@ class SwitchBackLinear(nn.Linear): ...@@ -219,18 +213,22 @@ class SwitchBackLinear(nn.Linear):
X_int8, state_X = quantize_rowwise(X) X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization: if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize( return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias *x.size()[:-1],
).view(*x.size()[:-1], -1) -1,
)
else: else:
return int8_matmul_mixed_dequantize( return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias *x.size()[:-1],
).view(*x.size()[:-1], -1) -1,
)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
# This is just the standard linear function. # This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function): class StandardLinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
...@@ -260,7 +258,7 @@ class StandardLinearFunction(torch.autograd.Function): ...@@ -260,7 +258,7 @@ class StandardLinearFunction(torch.autograd.Function):
return grad_input, grad_weight, grad_bias return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
class StandardLinear(nn.Linear):
def forward(self, x): def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias) return StandardLinearFunction.apply(x, self.weight, self.bias)
...@@ -50,9 +50,7 @@ class Adagrad(Optimizer1State): ...@@ -50,9 +50,7 @@ class Adagrad(Optimizer1State):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}") raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
...@@ -119,9 +117,7 @@ class Adagrad8bit(Optimizer1State): ...@@ -119,9 +117,7 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}") raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
...@@ -189,9 +185,7 @@ class Adagrad32bit(Optimizer1State): ...@@ -189,9 +185,7 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}") raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
......
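The hyperparameter guards in the Adagrad variants above keep the "not 0.0 <= lr" pattern rather than "lr < 0.0"; the two are not equivalent, because a NaN value fails the first check but would slip past the second. A quick illustration:

    lr = float("nan")
    assert not (0.0 <= lr)   # rejected by the guard used above
    assert not (lr < 0.0)    # ...but would pass a plain `lr < 0.0` check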
...@@ -14,8 +14,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State ...@@ -14,8 +14,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State): class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Base Adam optimizer. Base Adam optimizer.
...@@ -45,11 +58,38 @@ class Adam(Optimizer2State): ...@@ -45,11 +58,38 @@ class Adam(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Adam8bit(Optimizer2State): class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit Adam optimizer. 8-bit Adam optimizer.
...@@ -79,11 +119,38 @@ class Adam8bit(Optimizer2State): ...@@ -79,11 +119,38 @@ class Adam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Adam32bit(Optimizer2State): class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
32-bit Adam optimizer. 32-bit Adam optimizer.
...@@ -113,11 +180,38 @@ class Adam32bit(Optimizer2State): ...@@ -113,11 +180,38 @@ class Adam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdam(Optimizer2State): class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Paged Adam optimizer. Paged Adam optimizer.
...@@ -147,11 +241,38 @@ class PagedAdam(Optimizer2State): ...@@ -147,11 +241,38 @@ class PagedAdam(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam8bit(Optimizer2State): class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit paged Adam optimizer. 8-bit paged Adam optimizer.
...@@ -181,11 +302,38 @@ class PagedAdam8bit(Optimizer2State): ...@@ -181,11 +302,38 @@ class PagedAdam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam32bit(Optimizer2State): class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Paged 32-bit Adam optimizer. Paged 32-bit Adam optimizer.
...@@ -215,7 +363,21 @@ class PagedAdam32bit(Optimizer2State): ...@@ -215,7 +363,21 @@ class PagedAdam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
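All of the Adam variants above share the same constructor signature and are intended as drop-in replacements for torch.optim.Adam; the 8-bit and paged flavors only change how the optimizer state is stored. A minimal usage sketch (assumes a CUDA device and a working bitsandbytes build; the model and data are placeholders):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(4096, 4096).cuda()
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    # optimizer = bnb.optim.PagedAdam32bit(model.parameters(), lr=1e-3)  # paged variant

    x = torch.randn(16, 4096, device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()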
class AnalysisAdam(torch.optim.Optimizer): class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis. """Adam that performs 8-bit vs 32-bit error analysis.
...@@ -293,9 +455,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -293,9 +455,7 @@ class AnalysisAdam(torch.optim.Optimizer):
if grad.dtype in {torch.float16, torch.bfloat16}: if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float() grad = grad.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
amsgrad = group.get("amsgrad", False) amsgrad = group.get("amsgrad", False)
assert not amsgrad assert not amsgrad
...@@ -312,15 +472,9 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -312,15 +472,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32) state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
state["abserrors"] = torch.zeros( state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
(256, 256), device=p_data_fp32.device state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
) state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["counts"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
...@@ -328,25 +482,19 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -328,25 +482,19 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = state["exp_avg"].to(p_data_fp32) state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
if amsgrad: if amsgrad:
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
p_data_fp32
)
state["step"] += 1 state["step"] += 1
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"] bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"]
step_size = ( step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
group["lr"] * math.sqrt(bias_correction2) / bias_correction1
)
e = state["abserrors"] e = state["abserrors"]
rele = state["relerrors"] rele = state["relerrors"]
counts = state["counts"] counts = state["counts"]
if group["weight_decay"] != 0: if group["weight_decay"] != 0:
p_data_fp32.add_( p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad: if amsgrad:
...@@ -359,10 +507,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -359,10 +507,7 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"]) denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom update_fp32 = exp_avg / denom
if ( if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
p_data_fp32.numel() <= 8192
or p_data_fp32.numel() > 50000 * 1000
):
# embedding layer or too small # embedding layer or too small
p_data_fp32 += -step_size * update_fp32 p_data_fp32 += -step_size * update_fp32
else: else:
...@@ -401,9 +546,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -401,9 +546,7 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize # 3. dequantize
# Error will be calculated automatically! # Error will be calculated automatically!
else: else:
raise ValueError( raise ValueError(f"Invalid analysis value: {self.analysis}!")
f"Invalid analysis value: {self.analysis}!"
)
denom = state2.sqrt().add_(group["eps"]) denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom update_8bit = state1 / denom
...@@ -415,9 +558,7 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -415,9 +558,7 @@ class AnalysisAdam(torch.optim.Optimizer):
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
F.histogram_scatter_add_2d( F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
counts, C1.int(), C2.int(), torch.ones_like(abserr)
)
p_data_fp32 += -step_size * update_fp32 p_data_fp32 += -step_size * update_fp32
...@@ -425,18 +566,10 @@ class AnalysisAdam(torch.optim.Optimizer): ...@@ -425,18 +566,10 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0: if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir): if not os.path.exists(self.savedir):
os.makedirs(self.savedir) os.makedirs(self.savedir)
shapestr = "_".join( shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
[str(dim) for dim in p_data_fp32.shape] pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl")
) pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl")
pathe = os.path.join( pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl")
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
pathrele = os.path.join(
self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
)
pathcounts = os.path.join(
self.savedir, f"{p_id}_{shapestr}_counts.pkl"
)
torch.save(e, pathe) torch.save(e, pathe)
torch.save(rele, pathrele) torch.save(rele, pathrele)
torch.save(counts, pathcounts) torch.save(counts, pathcounts)
......
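For reference, the bias-corrected step size computed in AnalysisAdam above (step_size = lr * sqrt(1 - beta2**t) / (1 - beta1**t)) corresponds to the textbook dense Adam update. The sketch below reproduces a single step on one tensor for comparison purposes only; it is not the bitsandbytes kernel.

    import math
    import torch

    def adam_step_ref(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        beta1, beta2 = betas
        # first and second moment estimates
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # bias-corrected step size, matching the step_size line above
        step_size = lr * math.sqrt(1 - beta2**step) / (1 - beta1**step)
        p.add_(exp_avg / (exp_avg_sq.sqrt() + eps), alpha=-step_size)
        return p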
...@@ -6,8 +6,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State ...@@ -6,8 +6,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State): class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Base AdamW optimizer. Base AdamW optimizer.
...@@ -37,11 +50,38 @@ class AdamW(Optimizer2State): ...@@ -37,11 +50,38 @@ class AdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW8bit(Optimizer2State): class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit AdamW optimizer. 8-bit AdamW optimizer.
...@@ -71,11 +111,38 @@ class AdamW8bit(Optimizer2State): ...@@ -71,11 +111,38 @@ class AdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW32bit(Optimizer2State): class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
32-bit AdamW optimizer. 32-bit AdamW optimizer.
...@@ -105,12 +172,37 @@ class AdamW32bit(Optimizer2State): ...@@ -105,12 +172,37 @@ class AdamW32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdamW(Optimizer2State): class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged AdamW optimizer. Paged AdamW optimizer.
...@@ -140,11 +232,37 @@ class PagedAdamW(Optimizer2State): ...@@ -140,11 +232,37 @@ class PagedAdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW8bit(Optimizer2State): class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 8-bit AdamW optimizer. Paged 8-bit AdamW optimizer.
...@@ -174,11 +292,37 @@ class PagedAdamW8bit(Optimizer2State): ...@@ -174,11 +292,37 @@ class PagedAdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW32bit(Optimizer2State): class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 32-bit AdamW optimizer. Paged 32-bit AdamW optimizer.
...@@ -208,4 +352,17 @@ class PagedAdamW32bit(Optimizer2State): ...@@ -208,4 +352,17 @@ class PagedAdamW32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
...@@ -51,9 +51,7 @@ class LARS(Optimizer1State): ...@@ -51,9 +51,7 @@ class LARS(Optimizer1State):
The maximum gradient norm. The maximum gradient norm.
""" """
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError("LARS without momentum is not supported!")
"LARS without momentum is not supported!"
)
super().__init__( super().__init__(
"lars", "lars",
params, params,
...@@ -110,9 +108,7 @@ class LARS8bit(Optimizer1State): ...@@ -110,9 +108,7 @@ class LARS8bit(Optimizer1State):
The maximum gradient norm. The maximum gradient norm.
""" """
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError("LARS without momentum is not supported!")
"LARS without momentum is not supported!"
)
super().__init__( super().__init__(
"lars", "lars",
params, params,
...@@ -169,9 +165,7 @@ class LARS32bit(Optimizer1State): ...@@ -169,9 +165,7 @@ class LARS32bit(Optimizer1State):
The maximum gradient norm. The maximum gradient norm.
""" """
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError("LARS without momentum is not supported!")
"LARS without momentum is not supported!"
)
super().__init__( super().__init__(
"lars", "lars",
params, params,
...@@ -204,9 +198,7 @@ class PytorchLARS(Optimizer): ...@@ -204,9 +198,7 @@ class PytorchLARS(Optimizer):
if momentum < 0.0: if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}") raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0: if weight_decay < 0.0:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict( defaults = dict(
lr=lr, lr=lr,
...@@ -217,9 +209,7 @@ class PytorchLARS(Optimizer): ...@@ -217,9 +209,7 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm, max_unorm=max_unorm,
) )
if nesterov and (momentum <= 0 or dampening != 0): if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError( raise ValueError("Nesterov momentum requires a momentum and zero dampening")
"Nesterov momentum requires a momentum and zero dampening"
)
super().__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state): def __setstate__(self, state):
......
...@@ -6,7 +6,19 @@ from bitsandbytes.optim.optimizer import Optimizer1State ...@@ -6,7 +6,19 @@ from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State): class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
Base Lion optimizer. Base Lion optimizer.
...@@ -32,10 +44,35 @@ class Lion(Optimizer1State): ...@@ -32,10 +44,35 @@ class Lion(Optimizer1State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion8bit(Optimizer1State): class Lion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
8-bit Lion optimizer. 8-bit Lion optimizer.
...@@ -59,10 +96,35 @@ class Lion8bit(Optimizer1State): ...@@ -59,10 +96,35 @@ class Lion8bit(Optimizer1State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion32bit(Optimizer1State): class Lion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
""" """
32-bit Lion optimizer. 32-bit Lion optimizer.
...@@ -86,11 +148,35 @@ class Lion32bit(Optimizer1State): ...@@ -86,11 +148,35 @@ class Lion32bit(Optimizer1State):
is_paged (`bool`, defaults to `False`): is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not. Whether the optimizer is a paged optimizer or not.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedLion(Optimizer1State): class PagedLion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged Lion optimizer. Paged Lion optimizer.
...@@ -114,10 +200,34 @@ class PagedLion(Optimizer1State): ...@@ -114,10 +200,34 @@ class PagedLion(Optimizer1State):
block_wise (`bool`, defaults to `True`): block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion8bit(Optimizer1State): class PagedLion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 8-bit Lion optimizer. Paged 8-bit Lion optimizer.
...@@ -141,10 +251,34 @@ class PagedLion8bit(Optimizer1State): ...@@ -141,10 +251,34 @@ class PagedLion8bit(Optimizer1State):
block_wise (`bool`, defaults to `True`): block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion32bit(Optimizer1State): class PagedLion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
""" """
Paged 32-bit Lion optimizer. Paged 32-bit Lion optimizer.
...@@ -168,4 +302,17 @@ class PagedLion32bit(Optimizer1State): ...@@ -168,4 +302,17 @@ class PagedLion32bit(Optimizer1State):
block_wise (`bool`, defaults to `True`): block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
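The Lion variants above differ only in state precision and paging; the underlying update rule, reproduced here as an unquantized reference (from the Lion paper, not the bitsandbytes kernel), takes the sign of an interpolated momentum and applies decoupled weight decay:

    import torch

    def lion_step_ref(p, grad, m, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        beta1, beta2 = betas
        update = (beta1 * m + (1 - beta1) * grad).sign()
        p.mul_(1 - lr * weight_decay)      # decoupled weight decay
        p.add_(update, alpha=-lr)          # signed update
        m.mul_(beta2).add_(grad, alpha=1 - beta2)  # momentum update
        return p, m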
...@@ -21,6 +21,7 @@ class GlobalOptimManager: ...@@ -21,6 +21,7 @@ class GlobalOptimManager:
""" """
A global optimizer manager for enabling custom optimizer configs. A global optimizer manager for enabling custom optimizer configs.
""" """
_instance = None _instance = None
def __init__(self): def __init__(self):
...@@ -48,13 +49,9 @@ class GlobalOptimManager: ...@@ -48,13 +49,9 @@ class GlobalOptimManager:
for group_index, group in enumerate(param_groups): for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]): for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config: if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[ self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
id(p)
]
def override_config( def override_config(self, parameters, key=None, value=None, key_value_dict=None):
self, parameters, key=None, value=None, key_value_dict=None
):
""" """
Override initial optimizer config with specific hyperparameters. Override initial optimizer config with specific hyperparameters.
...@@ -170,16 +167,12 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -170,16 +167,12 @@ class Optimizer8bit(torch.optim.Optimizer):
saved_groups = state_dict["param_groups"] saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups): if len(groups) != len(saved_groups):
raise ValueError( raise ValueError("loaded state dict has a different number of parameter groups")
"loaded state dict has a different number of "
"parameter groups"
)
param_lens = (len(g["params"]) for g in groups) param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups) saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError( raise ValueError(
"loaded state dict contains a parameter group " "loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
"that doesn't match the size of optimizer's group"
) )
# Update the state # Update the state
...@@ -228,9 +221,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -228,9 +221,7 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"] new_group["params"] = group["params"]
return new_group return new_group
param_groups = [ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
update_group(g, ng) for g, ng in zip(groups, saved_groups)
]
self.__setstate__({"state": state, "param_groups": param_groups}) self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self): def to_gpu(self):
...@@ -240,7 +231,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -240,7 +231,7 @@ class Optimizer8bit(torch.optim.Optimizer):
values = self.state[p] values = self.state[p]
for k, v in values.items(): for k, v in values.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
is_paged = getattr(v, 'is_paged', False) is_paged = getattr(v, "is_paged", False)
if not is_paged: if not is_paged:
self.state[p][k] = v.to(p.device) self.state[p][k] = v.to(p.device)
...@@ -248,9 +239,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -248,9 +239,7 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple: for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr) pmodule = getattr(module, attr)
assert pmodule is not None assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance( assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
pmodule, torch.Parameter
)
found = False found = False
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
if found: if found:
...@@ -262,9 +251,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -262,9 +251,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter # found the matching parameter
# init override # init override
self.mng.pid2config[id(p)] = config self.mng.pid2config[id(p)] = config
self.mng.index2config[ self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
(gindex, pindex)
] = self.mng.pid2config[id(p)]
found = True found = True
@torch.no_grad() @torch.no_grad()
...@@ -287,7 +274,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -287,7 +274,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True self.initialized = True
#if self.is_paged: self.page_mng.prefetch_all() # if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]): for pindex, p in enumerate(group["params"]):
if p.grad is None: if p.grad is None:
...@@ -304,7 +291,6 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -304,7 +291,6 @@ class Optimizer8bit(torch.optim.Optimizer):
# to sync to make sure all tensors are in the right state # to sync to make sure all tensors are in the right state
torch.cuda.synchronize() torch.cuda.synchronize()
return loss return loss
def get_config(self, gindex, pindex, group): def get_config(self, gindex, pindex, group):
...@@ -328,9 +314,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -328,9 +314,7 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError("init_state method needs to be overridden") raise NotImplementedError("init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex): def update_step(self, group, p, gindex, pindex):
raise NotImplementedError( raise NotImplementedError("The update_step method needs to be overridden")
"The update_step method needs to be overridden"
)
def get_state_buffer(self, p, dtype=torch.float32): def get_state_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 1e5: if not self.is_paged or p.numel() < 1e5:
...@@ -345,12 +329,12 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -345,12 +329,12 @@ class Optimizer8bit(torch.optim.Optimizer):
def prefetch_state(self, p): def prefetch_state(self, p):
if self.is_paged: if self.is_paged:
state = self.state[p] state = self.state[p]
s1 = state['state1'] s1 = state["state1"]
is_paged = getattr(s1, 'is_paged', False) is_paged = getattr(s1, "is_paged", False)
if is_paged: if is_paged:
F.prefetch_tensor(state['state1']) F.prefetch_tensor(state["state1"])
if 'state2' in state: if "state2" in state:
F.prefetch_tensor(state['state2']) F.prefetch_tensor(state["state2"])
class Optimizer2State(Optimizer8bit): class Optimizer2State(Optimizer8bit):
...@@ -369,7 +353,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -369,7 +353,7 @@ class Optimizer2State(Optimizer8bit):
block_wise=True, block_wise=True,
max_unorm=0.0, max_unorm=0.0,
skip_zeros=False, skip_zeros=False,
is_paged=False is_paged=False,
): ):
""" """
Base 2-state update optimizer class. Base 2-state update optimizer class.
...@@ -414,13 +398,9 @@ class Optimizer2State(Optimizer8bit): ...@@ -414,13 +398,9 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas] betas = [float(b) for b in betas]
for i in range(len(betas)): for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0: if not 0.0 <= betas[i] < 1.0:
raise ValueError( raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
f"Invalid beta parameter at index {i}: {betas[i]}"
)
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged) super().__init__(params, defaults, optim_bits, is_paged)
...@@ -449,9 +429,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -449,9 +429,7 @@ class Optimizer2State(Optimizer8bit):
elif config["optim_bits"] == 8: elif config["optim_bits"] == 8:
dtype = torch.uint8 dtype = torch.uint8
else: else:
raise NotImplementedError( raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config["min_8bit_size"]: if p.numel() < config["min_8bit_size"]:
dtype = torch.float32 dtype = torch.float32
...@@ -459,21 +437,15 @@ class Optimizer2State(Optimizer8bit): ...@@ -459,21 +437,15 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p] state = self.state[p]
state["step"] = 0 state["step"] = 0
if dtype == torch.float32 or ( if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state["step"] == 0: if state["step"] == 0:
if "dynamic" not in self.name2qmap: if "dynamic" not in self.name2qmap:
self.fill_qmap() self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
p.device self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
p.device
)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"] state["qmap1"] = self.name2qmap["dynamic"]
...@@ -486,25 +458,13 @@ class Optimizer2State(Optimizer8bit): ...@@ -486,25 +458,13 @@ class Optimizer2State(Optimizer8bit):
blocks = n // 2048 blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0 blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros( state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
(blocks,), dtype=torch.float32, device=p.device state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else: else:
state["max1"] = torch.zeros( state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
(1,), dtype=torch.float32, device=p.device state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
) state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros( state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device) state["gnorm_vec"] = torch.zeros((100,), device=p.device)
...@@ -524,7 +484,10 @@ class Optimizer2State(Optimizer8bit): ...@@ -524,7 +484,10 @@ class Optimizer2State(Optimizer8bit):
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"] grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
) )
else: else:
gnorm_scale = 1.0 gnorm_scale = 1.0
...@@ -568,9 +531,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -568,9 +531,7 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"], state["new_max2"],
config["weight_decay"], config["weight_decay"],
gnorm_scale=gnorm_scale, gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"] unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
if config["max_unorm"] > 0.0
else None,
max_unorm=config["max_unorm"], max_unorm=config["max_unorm"],
) )
...@@ -615,7 +576,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -615,7 +576,7 @@ class Optimizer1State(Optimizer8bit):
block_wise=True, block_wise=True,
max_unorm=0.0, max_unorm=0.0,
skip_zeros=False, skip_zeros=False,
is_paged=False is_paged=False,
): ):
""" """
Base 1-state update optimizer class. Base 1-state update optimizer class.
...@@ -656,13 +617,9 @@ class Optimizer1State(Optimizer8bit): ...@@ -656,13 +617,9 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid epsilon value: {eps}") raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)): for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0: if not 0.0 <= betas[i] < 1.0:
raise ValueError( raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
f"Invalid beta parameter at index {i}: {betas[i]}"
)
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(f"Invalid weight_decay value: {weight_decay}")
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged) super().__init__(params, defaults, optim_bits, is_paged)
...@@ -691,9 +648,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -691,9 +648,7 @@ class Optimizer1State(Optimizer8bit):
elif config["optim_bits"] == 8: elif config["optim_bits"] == 8:
dtype = torch.uint8 dtype = torch.uint8
else: else:
raise NotImplementedError( raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config["min_8bit_size"]: if p.numel() < config["min_8bit_size"]:
dtype = torch.float32 dtype = torch.float32
...@@ -701,17 +656,13 @@ class Optimizer1State(Optimizer8bit): ...@@ -701,17 +656,13 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p] state = self.state[p]
state["step"] = 0 state["step"] = 0
if dtype == torch.float32 or ( if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state["step"] == 0: if state["step"] == 0:
if "dynamic" not in self.name2qmap: if "dynamic" not in self.name2qmap:
self.fill_qmap() self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
p.device
)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"] state["qmap1"] = self.name2qmap["dynamic"]
...@@ -721,16 +672,10 @@ class Optimizer1State(Optimizer8bit): ...@@ -721,16 +672,10 @@ class Optimizer1State(Optimizer8bit):
blocks = n // 2048 blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0 blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros( state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
(blocks,), dtype=torch.float32, device=p.device
)
else: else:
state["max1"] = torch.zeros( state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
(1,), dtype=torch.float32, device=p.device state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device) state["gnorm_vec"] = torch.zeros((100,), device=p.device)
...@@ -750,7 +695,10 @@ class Optimizer1State(Optimizer8bit): ...@@ -750,7 +695,10 @@ class Optimizer1State(Optimizer8bit):
if config["percentile_clipping"] < 100: if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"] grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
) )
else: else:
gnorm_scale = 1.0 gnorm_scale = 1.0
...@@ -766,7 +714,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -766,7 +714,7 @@ class Optimizer1State(Optimizer8bit):
step, step,
config["lr"], config["lr"],
None, None,
config['betas'][1], config["betas"][1],
config["weight_decay"], config["weight_decay"],
gnorm_scale, gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None, state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......
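GlobalOptimManager.override_config, reformatted above, is the hook for giving individual parameters their own optimizer settings, for example keeping a sensitive weight in 32-bit state while the rest of the model uses 8-bit. A usage sketch along the lines of the bitsandbytes README (the model and layer indices are placeholders; parameters should be registered while still on CPU, before the optimizer is created):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 10))

    mng = bnb.optim.GlobalOptimManager.get_instance()
    mng.register_parameters(model.parameters())

    model = model.cuda()
    optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)

    # keep the first layer's weight in 32-bit optimizer state
    mng.override_config(model[0].weight, "optim_bits", 32)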
...@@ -51,9 +51,7 @@ class RMSprop(Optimizer1State): ...@@ -51,9 +51,7 @@ class RMSprop(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
"RMSprop with alpha==0.0 is not supported!"
)
if centered: if centered:
raise NotImplementedError("Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
...@@ -116,9 +114,7 @@ class RMSprop8bit(Optimizer1State): ...@@ -116,9 +114,7 @@ class RMSprop8bit(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
""" """
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
"RMSprop with alpha==0.0 is not supported!"
)
if centered: if centered:
raise NotImplementedError("Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
...@@ -182,9 +178,7 @@ class RMSprop32bit(Optimizer1State): ...@@ -182,9 +178,7 @@ class RMSprop32bit(Optimizer1State):
""" """
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
"RMSprop with alpha==0.0 is not supported!"
)
if centered: if centered:
raise NotImplementedError("Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
......
...@@ -195,9 +195,9 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -195,9 +195,9 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.B = B ctx.B = B
ctx.bias = bias ctx.bias = bias
if A.shape[-1] == B.shape[0]: if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
else: else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A # 1. Quantize A
# 2. Quantize B # 2. Quantize B
...@@ -216,9 +216,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -216,9 +216,7 @@ class SwitchBackBnb(torch.autograd.Function):
# 1. Quantize A # 1. Quantize A
if len(A.shape) == 3: if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous() A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
A.to(torch.float16), threshold=state.threshold
)
if state.threshold > 0.0 and coo_tensorA is not None: if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights: if state.has_fp16_weights:
...@@ -234,14 +232,14 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -234,14 +232,14 @@ class SwitchBackBnb(torch.autograd.Function):
# we also need to convert it to the turing/ampere format # we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else: else:
#print('A shape', A.shape) # print('A shape', A.shape)
if not state.has_fp16_weights and state.CxB is None: if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None subA = None
# 2. Quantize B # 2. Quantize B
if state.has_fp16_weights: if state.has_fp16_weights:
#print('B shape', B.shape) # print('B shape', B.shape)
has_grad = True if (getattr(B, "grad", None) is not None) else False has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: if is_transposed:
...@@ -272,12 +270,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -272,12 +270,7 @@ class SwitchBackBnb(torch.autograd.Function):
# else: # else:
# state.idx = outlier_idx # state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = ( state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.to(A.dtype)
)
CA[:, state.idx.long()] = 0 CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()] subA = A[:, state.idx.long()]
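The threshold/outlier handling above is the LLM.int8()-style mixed decomposition: columns of A whose magnitude exceeds state.threshold are pulled out (subA, state.subB) and multiplied in 16-bit precision, while the remaining columns go through the int8 path and the two partial results are summed. Conceptually (an illustrative sketch only, with the int8 path simulated in float, not the library kernels):

    import torch

    def mixed_decomposition_ref(A, W, threshold=6.0):
        # A: [n, d] activations, W: [out, d] weights, both float
        outlier_cols = (A.abs().amax(dim=0) > threshold).nonzero().flatten()
        regular = torch.ones(A.shape[1], dtype=torch.bool)
        regular[outlier_cols] = False
        # int8 path for the well-behaved columns (simulated in float here)
        out_int8_part = A[:, regular] @ W[:, regular].t()
        # higher-precision path for the outlier columns
        out_fp_part = A[:, outlier_cols] @ W[:, outlier_cols].t()
        return out_int8_part + out_fp_part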
...@@ -320,14 +313,13 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -320,14 +313,13 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.tensor_states = (None, None) ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None) ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
return clone_func(output.view(output_shape)) return clone_func(output.view(output_shape))
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
if ctx.is_empty: if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors CAt, subA, A = ctx.tensors
...@@ -342,9 +334,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -342,9 +334,7 @@ class SwitchBackBnb(torch.autograd.Function):
# Cast grad_output to fp16 # Cast grad_output to fp16
if len(grad_output.shape) == 3: if len(grad_output.shape) == 3:
grad_output = grad_output.reshape( grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
-1, grad_output.shape[-1]
).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
...@@ -357,25 +347,24 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -357,25 +347,24 @@ class SwitchBackBnb(torch.autograd.Function):
if state.CBt is not None: if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32") C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None: if state.CxBt is None:
state.CxBt, state.SBt = F.transform( state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
state.CBt, to_order=formatB, transpose=True
)
# print('back B shape', state.CxBt.shape) # print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape) # print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None: elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else: else:
raise Exception('State must contain either CBt or CB matrix for backward') raise Exception("State must contain either CBt or CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None return grad_A, grad_B, None, grad_bias, None
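A minimal numeric sketch of the state.CB fallback above (my own illustration, not from the commit): dequantize the int8 weight with its per-row absmax scale, then compute grad_A with an ordinary matmul. Dtypes are float32 here purely so the sketch runs on CPU; the real branch casts to ctx.dtype_A.

    import torch

    CB = torch.randint(-127, 128, (4096, 1024), dtype=torch.int8)  # int8 weight, shape (out, in)
    SCB = torch.rand(4096) * 10.0                                  # per-row absmax scales
    grad_output = torch.randn(8, 4096)                             # (batch, out)

    # Same arithmetic as the elif state.CB branch: W = CB * SCB / 127, then grad_output @ W.
    W = CB.to(torch.float32) * (SCB.unsqueeze(1) * (1.0 / 127.0))
    grad_A = grad_output @ W                                       # shape (8, 1024)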
def get_block_sizes(input_matrix, weight_matrix): def get_block_sizes(input_matrix, weight_matrix):
input_features = input_matrix.shape[-1] input_features = input_matrix.shape[-1]
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]
array = [4096, 2048, 1024, 512, 256, 128, 64, 0] array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
bsz, bsz2 = 1024, 1024 bsz, bsz2 = 1024, 1024
for i, k in enumerate(array): for i, k in enumerate(array):
...@@ -399,7 +388,8 @@ def matmul_fp8_global( ...@@ -399,7 +388,8 @@ def matmul_fp8_global(
bsz: int = -1, bsz: int = -1,
bsz2: int = -1, bsz2: int = -1,
): ):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
...@@ -412,7 +402,8 @@ def matmul_fp8_mixed( ...@@ -412,7 +402,8 @@ def matmul_fp8_mixed(
bsz: int = -1, bsz: int = -1,
bsz2: int = -1, bsz2: int = -1,
): ):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
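A hedged usage sketch for the two FP8 wrappers above (my own illustration; it assumes a CUDA device and that both functions are exported under bnb.research, which is how LinearFP8Mixed below calls its variant). Leaving bsz and bsz2 at their -1 defaults lets get_block_sizes(A, B) choose them.

    import torch
    import bitsandbytes as bnb

    A = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    B = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)

    # E4M3 forward / E5M2 backward code maps, built the same way as in the LinearFP8 layers below.
    fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
    bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

    out_global = bnb.research.matmul_fp8_global(A, B, fw_code=fw_code, bw_code=bw_code)
    out_mixed = bnb.research.matmul_fp8_mixed(A, B, fw_code=fw_code, bw_code=bw_code)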
...@@ -422,7 +413,7 @@ def switchback_bnb( ...@@ -422,7 +413,7 @@ def switchback_bnb(
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None, state: Optional[MatmulLtState] = None,
threshold=0.0, threshold=0.0,
bias=None bias=None,
): ):
state = state or MatmulLtState() state = state or MatmulLtState()
if threshold > 0.0: if threshold > 0.0:
......
...@@ -28,12 +28,20 @@ class LinearFP8Mixed(nn.Linear): ...@@ -28,12 +28,20 @@ class LinearFP8Mixed(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) out = bnb.research.matmul_fp8_mixed(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None: if self.bias is not None:
out += self.bias out += self.bias
return out return out
class LinearFP8Global(nn.Linear): class LinearFP8Global(nn.Linear):
def __init__(self, input_features, output_features, bias=True): def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
...@@ -54,7 +62,14 @@ class LinearFP8Global(nn.Linear): ...@@ -54,7 +62,14 @@ class LinearFP8Global(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) out = bnb.matmul_fp8_global(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None: if self.bias is not None:
out += self.bias out += self.bias
......
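A hedged usage sketch for the two layers above (my own illustration; the import path is an assumption and a CUDA device is assumed). Both classes keep the nn.Linear constructor, so they are intended as drop-in replacements.

    import torch

    # Import path is an assumption; adjust to wherever these classes live in your tree.
    from bitsandbytes.research.nn import LinearFP8Mixed, LinearFP8Global

    layer = LinearFP8Global(1024, 4096, bias=True).cuda().half()
    x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    y = layer(x)      # fw/bw code maps and block sizes are set up lazily on the first forward
    print(y.shape)    # torch.Size([8, 4096])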
...@@ -5,9 +5,10 @@ import torch ...@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available(): if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
return None
else:
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -29,7 +30,7 @@ else: ...@@ -29,7 +30,7 @@ else:
triton.Config({}, num_warps=4), triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8), triton.Config({}, num_warps=8),
], ],
key=['n_elements'] key=["n_elements"],
) )
@triton.jit @triton.jit
def _dequantize_rowwise( def _dequantize_rowwise(
...@@ -51,7 +52,6 @@ else: ...@@ -51,7 +52,6 @@ else:
output = max_val * x * inv_127 output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
...@@ -60,5 +60,5 @@ else: ...@@ -60,5 +60,5 @@ else:
assert x.is_cuda and output.is_cuda assert x.is_cuda and output.is_cuda
n_elements = output.numel() n_elements = output.numel()
grid = lambda meta: (x.shape[0],) grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output return output
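A hedged round-trip sketch for the kernel above (my own illustration; it assumes a CUDA device with Triton installed and that this file is bitsandbytes/triton/dequantize_rowwise.py): quantize row-wise by hand with a per-row absmax, then recover the values with dequantize_rowwise.

    import torch
    from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

    x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    state_x = x.abs().max(dim=1).values                  # per-row absmax, matches max_val in the kernel
    x_int8 = torch.round(x / state_x.unsqueeze(1) * 127).to(torch.int8)

    x_back = dequantize_rowwise(x_int8, state_x)         # fp16, approximately x up to int8 rounding error
    print((x - x_back).abs().max())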