Commit 5a4263f4 authored by Ruff, committed by Aarni Koskela

Reformat with ruff-format

parent 02e30ca6
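
This commit is a pure style pass: Ruff's Black-compatible formatter normalizes strings to double quotes, breaks long call sites into one argument per line with trailing commas, and re-flows manually wrapped expressions to the configured line length, without changing behavior. A minimal sketch of how such a pass is typically produced is shown below; it assumes Ruff is installed and that the repository's own pyproject.toml supplies the formatter settings (line length, quote style), neither of which is visible in this diff.

```python
"""Sketch: reproduce a "Reformat with ruff-format" pass over a checkout.

Assumptions (not taken from this diff): Ruff is installed in the environment,
and the project's pyproject.toml carries its formatter settings, so a plain
`ruff format` picks them up automatically.
"""
import subprocess

# Rewrite all Python files under the current directory in place.
subprocess.run(["ruff", "format", "."], check=True)

# Verify the tree is clean: exits non-zero if any file would still be reformatted.
subprocess.run(["ruff", "format", "--check", "."], check=True)
```

The hunks that follow are consistent with such a run: quote normalization, argument-per-line wrapping with magic trailing commas, and blank-line normalization around class and function definitions.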
......@@ -7,9 +7,7 @@ def get_platform_tag(architecture):
system = platform.system()
if system == "Linux":
tag = (
"manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
)
tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
elif system == "Darwin":
tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
elif system == "Windows":
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
cmap = plt.get_cmap("cool")
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
if __name__ == "__main__":
fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
gs = gridspec.GridSpec(1, 2)
dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
......@@ -19,25 +17,28 @@ if __name__ == '__main__':
ax = fig.add_subplot(gs[0, 0])
# TODO: change this to what you want.
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
df = rdf[rdf.batch_size == batch_size_for_plot1]
# first plot the time occupied by different operations
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (sum of parts)",
),
("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
]:
xs = []
ys = []
......@@ -47,40 +48,46 @@ if __name__ == '__main__':
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
ax.plot(
xs,
ys,
color=color,
label=name,
marker=marker,
markersize=5 if marker == "s" else 5,
linestyle=ls,
linewidth=2 if "+" in k else 1.0,
)
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13)
ax.set_xlabel("dim", fontsize=13)
ax.set_ylabel("time (ms)", fontsize=13)
ax.grid()
ax.set_xscale('log')
ax.set_xscale("log")
if logscale_plot1:
ax.set_yscale('log')
ax.set_yscale("log")
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight('bold')
leg.get_texts()[1].set_fontweight('bold')
leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight("bold")
leg.get_texts()[1].set_fontweight("bold")
plt.subplots_adjust(left=0.1)
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)
ax = fig.add_subplot(gs[0, 1])
......@@ -88,10 +95,15 @@ if __name__ == '__main__':
for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (total time)",
),
]:
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in dims_to_consider:
......@@ -99,11 +111,11 @@ if __name__ == '__main__':
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
......@@ -111,25 +123,29 @@ if __name__ == '__main__':
color = cmap(j * 0.25)
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
markers = ["^", "v", "P", "o"]
ax.plot(
all_xs[0],
real_ys,
color=color,
label=f"batch * sequence length = {batch_size}",
marker=markers[j],
markersize=5 if marker == "s" else 5,
)
ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.set_xlabel("dim", fontsize=13)
ax.set_xscale("log")
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=13)
ax.set_ylabel(r"% speedup", fontsize=13)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)
plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
......@@ -20,8 +20,8 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embed_dim is too large.
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
......@@ -36,16 +36,15 @@ def get_time(k, fn, info_dict):
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
if __name__ == '__main__':
if __name__ == "__main__":
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
......@@ -73,35 +72,86 @@ if __name__ == '__main__':
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info = {
"repeat": repeat,
"batch_size": batch_size,
"dim_out": dim_out,
"dim_in": dim_in,
"wm": wm,
"switch": switch,
}
get_time("standard_fwd", lambda: x.matmul(w.t()), info)
get_time("standard_gw", lambda: g.t().matmul(x), info)
get_time("standard_gx", lambda: g.matmul(w), info)
get_time(
"rowwise_fwd",
lambda: int8_matmul_rowwise_dequantize(
x_int8,
w_int8.t(),
state_x_rowwise,
state_w_columnwise,
None,
),
info,
)
get_time(
"rowwise_bwd",
lambda: int8_matmul_rowwise_dequantize(
g_int8,
wt_int8.t(),
state_x_rowwise,
state_w_rowwise,
None,
),
info,
)
get_time(
"global_fwd",
lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time(
"global_bwd",
lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
get_time("w_quantize_global", lambda: quantize_global(w), info)
get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)
time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
time_rowwise = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_colwise_transpose"]
+ info["w_quantize_rowwise"]
+ info["standard_gw"]
+ info["rowwise_fwd"]
+ info["rowwise_bwd"]
)
time_global = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_global"]
+ info["w_quantize_global_transpose"]
+ info["standard_gw"]
+ info["global_fwd"]
+ info["global_bwd"]
)
print("TOTAL STANDARD", time_standard)
print("TOTAL ROWWISE", time_rowwise)
print("TOTAL GLOBAL", time_global)
print("speedup", -100 * (time_global - time_standard) / time_standard)
info["time_standard"] = time_standard
info["time_rowwise"] = time_rowwise
info["time_global"] = time_global
info_json = json.dumps(info)
......@@ -14,16 +14,18 @@ import bitsandbytes.functional as F
def prod(iterable):
return reduce(operator.mul, iterable, 1)
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler:
_instance = None
......@@ -83,6 +85,7 @@ def get_inverse_transform_indices(
break # if all indices fit in i bytes, stop early
return permuted_tile_indices
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
......@@ -159,20 +162,12 @@ class MatMul8bit(torch.autograd.Function):
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(
A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B, S2.t(), S1, grad_output.dtype, quant_type
)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
else:
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(
A, dim=dims, quant_type=quant_type
)
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
......@@ -201,9 +196,7 @@ class MatMul8bit(torch.autograd.Function):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
......@@ -227,7 +220,7 @@ def supports_igemmlt(device: torch.device) -> bool:
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
......@@ -246,6 +239,7 @@ def get_tile_inds(format, device):
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
......@@ -510,7 +504,6 @@ class MatMul4Bit(torch.autograd.Function):
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
......@@ -532,7 +525,7 @@ class MatMul4Bit(torch.autograd.Function):
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B, grad_bias = None, None, None
......@@ -542,8 +535,9 @@ class MatMul4Bit(torch.autograd.Function):
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
return grad_A, grad_B, None, grad_bias, None
......@@ -554,7 +548,7 @@ def matmul(
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None
bias=None,
):
state = state or MatmulLtState()
if threshold > 0.0:
......@@ -562,11 +556,19 @@ def matmul(
return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
def matmul_4bit(
A: torch.Tensor,
B: torch.Tensor,
quant_state: F.QuantState,
out: Optional[torch.Tensor] = None,
bias=None,
):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
)
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
......@@ -56,7 +56,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
"If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
"If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
"For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n"
"For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n",
)
return PACKAGE_DIR / library_name
......@@ -100,7 +100,7 @@ def get_native_library() -> BNBNativeLibrary:
logger.warning(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
)
return BNBNativeLibrary(dll)
......@@ -120,5 +120,5 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
"""
""",
)
......@@ -120,7 +120,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
The CUDA version for the compile might depend on your conda install, if using conda.
Inspect CUDA version via `conda list | grep cuda`.
"""
""",
)
cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
......@@ -129,7 +129,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
"""
WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
You will be only to use 8-bit optimizers and quantization routines!
"""
""",
)
print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
......@@ -170,7 +170,7 @@ def print_cuda_runtime_diagnostics() -> None:
In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
"""
""",
)
for pth in cudart_paths:
print(f"* Found CUDA runtime at: {pth}")
......@@ -25,7 +25,7 @@ def sanity_check():
See the documentation for more details if needed.
Trying a simple check anyway, but this will likely fail...
"""
""",
)
from bitsandbytes.optim import Adam
......@@ -71,7 +71,7 @@ def main():
print(
f"WARNING: {__package__} is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!"
f"If you think that this is so erroneously,\nplease report an issue!",
)
except Exception:
traceback.print_exc()
......@@ -80,6 +80,6 @@ def main():
Above we output some debug information.
Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose
WARNING: Please be sure to sanitize sensitive info from the output before posting it.
"""
""",
)
sys.exit(1)
(The diff for this file is collapsed and is not shown.)
......@@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding):
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
"""
def __init__(
self,
num_embeddings: int,
......@@ -89,9 +90,7 @@ class StableEmbedding(torch.nn.Embedding):
dtype,
)
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
)
GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
......@@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding):
"""
Embedding class to store and retrieve word embeddings from their indices.
"""
def __init__(
self,
num_embeddings: int,
......@@ -170,11 +170,9 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
device=device
)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
device=device,
)
GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32})
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
......@@ -214,10 +212,10 @@ class Params4bit(torch.nn.Parameter):
quant_state: Optional[QuantState] = None,
blocksize: int = 64,
compress_statistics: bool = True,
quant_type: str = 'fp4',
quant_type: str = "fp4",
quant_storage: torch.dtype = torch.uint8,
module: Optional["Linear4bit"] = None,
bnb_quantized: bool = False
bnb_quantized: bool = False,
) -> "Params4bit":
if data is None:
data = torch.empty(0)
......@@ -250,7 +248,7 @@ class Params4bit(torch.nn.Parameter):
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]
def __deepcopy__(self,memo):
def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
state = self.__getstate__()
new_instance.__setstate__(state)
......@@ -265,7 +263,14 @@ class Params4bit(torch.nn.Parameter):
return new_instance
@classmethod
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
def from_prequantized(
cls,
data: torch.Tensor,
quantized_stats: Dict[str, Any],
requires_grad: bool = False,
device="cuda",
**kwargs,
) -> "Params4bit":
self = torch.Tensor._make_subclass(cls, data.to(device))
self.requires_grad = requires_grad
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
......@@ -292,33 +297,39 @@ class Params4bit(torch.nn.Parameter):
return self
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
...
def to(
self: T,
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...,
) -> T: ...
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
...
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
...
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (device is not None and device.type == "cuda" and not self.bnb_quantized):
if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
self.quant_state.to(device)
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad, quant_state=self.quant_state,
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type)
new_param = Params4bit(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
quant_state=self.quant_state,
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
)
return new_param
......@@ -355,7 +366,18 @@ class Linear4bit(nn.Linear):
quantized_model = quantized_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_type="fp4",
quant_storage=torch.uint8,
device=None,
):
"""
Initialize Linear4bit class.
......@@ -368,7 +390,14 @@ class Linear4bit(nn.Linear):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
self.weight = Params4bit(
self.weight.data,
requires_grad=False,
compress_statistics=compress_statistics,
quant_type=quant_type,
quant_storage=quant_storage,
module=self,
)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
self.compute_type_is_set = False
......@@ -385,11 +414,15 @@ class Linear4bit(nn.Linear):
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
warnings.filterwarnings('ignore', message='.*inference.')
warnings.warn(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
)
warnings.filterwarnings("ignore", message=".*inference.")
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.filterwarnings('ignore', message='.*inference or training')
warnings.warn(
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
)
warnings.filterwarnings("ignore", message=".*inference or training")
def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
......@@ -407,8 +440,8 @@ class Linear4bit(nn.Linear):
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, 'quant_state', None) is None:
if getattr(self, 'quant_state', None) is not None:
if getattr(self.weight, "quant_state", None) is None:
if getattr(self, "quant_state", None) is not None:
# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1
......@@ -416,7 +449,9 @@ class Linear4bit(nn.Linear):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
self.weight.quant_state = self.quant_state
else:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)
if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
......@@ -437,7 +472,17 @@ class LinearFP4(Linear4bit):
"""
Implements the FP4 data type.
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_storage=torch.uint8,
device=None,
):
"""
Args:
input_features (`str`):
......@@ -447,11 +492,20 @@ class LinearFP4(Linear4bit):
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
super().__init__(
input_features,
output_features,
bias,
compute_dtype,
compress_statistics,
"fp4",
quant_storage,
device,
)
class LinearNF4(Linear4bit):
''' Implements the NF4 data type.
"""Implements the NF4 data type.
Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
is normalized into the range [-1, 1].
......@@ -460,8 +514,18 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
def __init__(
self,
input_features,
output_features,
bias=True,
compute_dtype=None,
compress_statistics=True,
quant_storage=torch.uint8,
device=None,
):
"""
Args:
input_features (`str`):
......@@ -471,7 +535,16 @@ class LinearNF4(Linear4bit):
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
super().__init__(
input_features,
output_features,
bias,
compute_dtype,
compress_statistics,
"nf4",
quant_storage,
device,
)
class Int8Params(torch.nn.Parameter):
......@@ -514,33 +587,22 @@ class Int8Params(torch.nn.Parameter):
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...,
) -> T:
...
) -> T: ...
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
...
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
...
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
)
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (
device is not None
and device.type == "cuda"
and self.data.device.type == "cpu"
):
if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
return self.cuda(device)
else:
new_param = Int8Params(
super().to(
device=device, dtype=dtype, non_blocking=non_blocking
),
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights,
)
......@@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear):
int8_model = int8_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None,
):
"""
Initialize Linear8bitLt class.
......@@ -647,19 +719,36 @@ class Linear8bitLt(nn.Linear):
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
unexpected_copy = list(unexpected_keys)
for key in unexpected_copy:
input_name = key[len(prefix):]
input_name = key[len(prefix) :]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
raise RuntimeError(
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()",
)
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
......@@ -702,18 +791,18 @@ class OutlierAwareLinear(nn.Linear):
self.is_quantized = False
def forward_with_outliers(self, x, outlier_idx):
raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function")
def quantize_weight(self, w, outlier_idx):
raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')
raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function")
def forward(self, x):
if self.outlier_dim is None:
tracer = OutlierTracer.get_instance()
if not tracer.is_initialized():
print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
outlier_idx = tracer.get_outliers(self.weight)
#print(outlier_idx, tracer.get_hvalue(self.weight))
# print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx
if not self.is_quantized:
......@@ -721,6 +810,7 @@ class OutlierAwareLinear(nn.Linear):
self.weight.data.copy_(w)
self.is_quantized = True
class SwitchBackLinearBnb(nn.Linear):
def __init__(
self,
......@@ -731,11 +821,9 @@ class SwitchBackLinearBnb(nn.Linear):
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None
device=None,
):
super().__init__(
input_features, output_features, bias, device
)
super().__init__(input_features, output_features, bias, device)
self.state = bnb.MatmulLtState()
self.index = index
......@@ -745,9 +833,7 @@ class SwitchBackLinearBnb(nn.Linear):
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
def init_8bit_state(self):
self.state.CB = self.weight.CB
......@@ -22,7 +22,6 @@ from bitsandbytes.triton.triton_utils import is_triton_available
class _switchback_global(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
......@@ -37,9 +36,7 @@ class _switchback_global(torch.autograd.Function):
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
......@@ -56,7 +53,8 @@ class _switchback_global(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
*G_3D.size()[:-1],
-1,
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
......@@ -66,8 +64,8 @@ class _switchback_global(torch.autograd.Function):
return grad_X, grad_W, grad_bias
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
......@@ -81,9 +79,7 @@ class _switchback_vectorrize(torch.autograd.Function):
# matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
......@@ -99,7 +95,8 @@ class _switchback_vectorrize(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
*G_3D.size()[:-1],
-1,
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
......@@ -109,8 +106,8 @@ class _switchback_vectorrize(torch.autograd.Function):
return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
......@@ -127,9 +124,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
......@@ -151,12 +146,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1)
return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear):
def __init__(
self,
......@@ -166,20 +160,20 @@ class SwitchBackLinear(nn.Linear):
device=None,
dtype=None,
vector_wise_quantization: bool = False,
mem_efficient : bool = False,
mem_efficient: bool = False,
):
super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available():
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""")
# By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization:
self._fn = _switchback_vectorrize
if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.')
print("mem efficient is not supported for vector-wise quantization.")
exit(1)
else:
if mem_efficient:
......@@ -195,7 +189,7 @@ class SwitchBackLinear(nn.Linear):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print('=> preparing for eval.')
print("=> preparing for eval.")
if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight)
else:
......@@ -219,18 +213,22 @@ class SwitchBackLinear(nn.Linear):
X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
*x.size()[:-1],
-1,
)
else:
return int8_matmul_mixed_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view(
*x.size()[:-1],
-1,
)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
@staticmethod
......@@ -260,7 +258,7 @@ class StandardLinearFunction(torch.autograd.Function):
return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias)
......@@ -50,9 +50,7 @@ class Adagrad(Optimizer1State):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
......@@ -119,9 +117,7 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
......@@ -189,9 +185,7 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
......@@ -14,8 +14,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Base Adam optimizer.
......@@ -45,11 +58,38 @@ class Adam(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit Adam optimizer.
......@@ -79,11 +119,38 @@ class Adam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
32-bit Adam optimizer.
......@@ -113,11 +180,38 @@ class Adam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Paged Adam optimizer.
......@@ -147,11 +241,38 @@ class PagedAdam(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit paged Adam optimizer.
......@@ -181,11 +302,38 @@ class PagedAdam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Paged 32-bit Adam optimizer.
......@@ -215,7 +363,21 @@ class PagedAdam32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis.
......@@ -293,9 +455,7 @@ class AnalysisAdam(torch.optim.Optimizer):
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
amsgrad = group.get("amsgrad", False)
assert not amsgrad
......@@ -312,15 +472,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
state["abserrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["counts"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
......@@ -328,25 +482,19 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
if amsgrad:
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
p_data_fp32
)
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
state["step"] += 1
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = (
group["lr"] * math.sqrt(bias_correction2) / bias_correction1
)
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
e = state["abserrors"]
rele = state["relerrors"]
counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
......@@ -359,10 +507,7 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom
if (
p_data_fp32.numel() <= 8192
or p_data_fp32.numel() > 50000 * 1000
):
if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
# embedding layer or too small
p_data_fp32 += -step_size * update_fp32
else:
......@@ -401,9 +546,7 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize
# Error will be calculated automatically!
else:
raise ValueError(
f"Invalid analysis value: {self.analysis}!"
)
raise ValueError(f"Invalid analysis value: {self.analysis}!")
denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom
......@@ -415,9 +558,7 @@ class AnalysisAdam(torch.optim.Optimizer):
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
F.histogram_scatter_add_2d(
counts, C1.int(), C2.int(), torch.ones_like(abserr)
)
F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
p_data_fp32 += -step_size * update_fp32
......@@ -425,18 +566,10 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir):
os.makedirs(self.savedir)
shapestr = "_".join(
[str(dim) for dim in p_data_fp32.shape]
)
pathe = os.path.join(
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
pathrele = os.path.join(
self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
)
pathcounts = os.path.join(
self.savedir, f"{p_id}_{shapestr}_counts.pkl"
)
shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl")
pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl")
pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl")
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
......@@ -6,8 +6,21 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Base AdamW optimizer.
......@@ -37,11 +50,38 @@ class AdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit AdamW optimizer.
......@@ -71,11 +111,38 @@ class AdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
32-bit AdamW optimizer.
......@@ -105,12 +172,37 @@ class AdamW32bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged AdamW optimizer.
......@@ -140,11 +232,37 @@ class PagedAdamW(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 8-bit AdamW optimizer.
......@@ -174,11 +292,37 @@ class PagedAdamW8bit(Optimizer2State):
is_paged (`bool`, fixed to `True`):
Paging is always enabled for this class; the flag is passed to the base class internally and is not a constructor argument.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 32-bit AdamW optimizer.
......@@ -208,4 +352,17 @@ class PagedAdamW32bit(Optimizer2State):
is_paged (`bool`, fixed to `True`):
Paging is always enabled for this class; the flag is passed to the base class internally and is not a constructor argument.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
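All of the AdamW variants above share one constructor shape and forward straight to Optimizer2State with the "adam" key; only the hard-coded optim_bits (8 or 32) and is_paged flag differ. A minimal usage sketch, assuming bitsandbytes is installed with CUDA support and imported as bnb; the model and data here are illustrative placeholders:

import torch
import torch.nn as nn
import bitsandbytes as bnb

# Any nn.Module works; parameters with at least min_8bit_size (4096) elements
# actually get 8-bit optimizer state, smaller ones stay in fp32.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()

optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)
# Paged variant (is_paged is hard-coded to True, so it is not a constructor argument):
# optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device="cuda")
model(x).pow(2).mean().backward()
optimizer.step()
optimizer.zero_grad()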
......@@ -51,9 +51,7 @@ class LARS(Optimizer1State):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
)
raise NotImplementedError("LARS without momentum is not supported!")
super().__init__(
"lars",
params,
......@@ -110,9 +108,7 @@ class LARS8bit(Optimizer1State):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
)
raise NotImplementedError("LARS without momentum is not supported!")
super().__init__(
"lars",
params,
......@@ -169,9 +165,7 @@ class LARS32bit(Optimizer1State):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
)
raise NotImplementedError("LARS without momentum is not supported!")
super().__init__(
"lars",
params,
......@@ -204,9 +198,7 @@ class PytorchLARS(Optimizer):
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
......@@ -217,9 +209,7 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening"
)
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
def __setstate__(self, state):
......
......@@ -6,7 +6,19 @@ from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
Base Lion optimizer.
......@@ -32,10 +44,35 @@ class Lion(Optimizer1State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
8-bit Lion optimizer.
......@@ -59,10 +96,35 @@ class Lion8bit(Optimizer1State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class Lion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
"""
32-bit Lion optimizer.
......@@ -86,11 +148,35 @@ class Lion32bit(Optimizer1State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
class PagedLion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged Lion optimizer.
......@@ -114,10 +200,34 @@ class PagedLion(Optimizer1State):
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 8-bit Lion optimizer.
......@@ -141,10 +251,34 @@ class PagedLion8bit(Optimizer1State):
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
class PagedLion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
"""
Paged 32-bit Lion optimizer.
......@@ -168,4 +302,17 @@ class PagedLion32bit(Optimizer1State):
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion",
params,
lr,
betas,
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=True,
)
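The Lion classes above mirror the AdamW family but pass eps=0.0 to Optimizer1State, since Lion's sign-based update has no epsilon term, and the paged variants again hard-code is_paged=True. A short usage sketch under the same assumptions as the AdamW example earlier (bitsandbytes with CUDA, placeholder model):

import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(1024, 4096).cuda()

# Lion is usually run with a smaller learning rate and larger weight decay than AdamW.
optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)
# optimizer = bnb.optim.PagedLion8bit(model.parameters(), lr=1e-4)  # paged 8-bit state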
......@@ -21,6 +21,7 @@ class GlobalOptimManager:
"""
A global optimizer manager for enabling custom optimizer configs.
"""
_instance = None
def __init__(self):
......@@ -48,13 +49,9 @@ class GlobalOptimManager:
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[
id(p)
]
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
def override_config(
self, parameters, key=None, value=None, key_value_dict=None
):
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
"""
Override initial optimizer config with specific hyperparameters.
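override_config is the hook GlobalOptimManager exposes for pinning individual parameters to different optimizer settings, e.g. keeping an embedding's state in 32-bit while the rest of the model uses 8-bit state. A sketch of the usual pattern; get_instance() and register_parameters() come from the library's published examples rather than from this hunk, so treat them as assumptions:

import torch.nn as nn
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()  # singleton, see the _instance field above

model = nn.Sequential(nn.Embedding(50_000, 1024), nn.Linear(1024, 1024))
mng.register_parameters(model.parameters())  # register while parameters are still on CPU
model = model.cuda()

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Keep the embedding's optimizer state in full precision; everything else
# stays on the optimizer's 8-bit default.
mng.override_config(model[0].weight, "optim_bits", 32)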
......@@ -170,16 +167,12 @@ class Optimizer8bit(torch.optim.Optimizer):
saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
raise ValueError(
"loaded state dict has a different number of "
"parameter groups"
)
raise ValueError("loaded state dict has a different number of parameter groups")
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
"loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
)
# Update the state
......@@ -228,9 +221,7 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"]
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)
]
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
......@@ -240,7 +231,7 @@ class Optimizer8bit(torch.optim.Optimizer):
values = self.state[p]
for k, v in values.items():
if isinstance(v, torch.Tensor):
is_paged = getattr(v, 'is_paged', False)
is_paged = getattr(v, "is_paged", False)
if not is_paged:
self.state[p][k] = v.to(p.device)
......@@ -248,9 +239,7 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance(
pmodule, torch.Parameter
)
assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
found = False
for gindex, group in enumerate(self.param_groups):
if found:
......@@ -262,9 +251,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
self.mng.index2config[
(gindex, pindex)
] = self.mng.pid2config[id(p)]
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
......@@ -287,7 +274,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
#if self.is_paged: self.page_mng.prefetch_all()
# if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
......@@ -304,7 +291,6 @@ class Optimizer8bit(torch.optim.Optimizer):
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
def get_config(self, gindex, pindex, group):
......@@ -328,9 +314,7 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError("init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(
"The update_step method needs to be overridden"
)
raise NotImplementedError("The update_step method needs to be overridden")
def get_state_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 1e5:
......@@ -345,12 +329,12 @@ class Optimizer8bit(torch.optim.Optimizer):
def prefetch_state(self, p):
if self.is_paged:
state = self.state[p]
s1 = state['state1']
is_paged = getattr(s1, 'is_paged', False)
s1 = state["state1"]
is_paged = getattr(s1, "is_paged", False)
if is_paged:
F.prefetch_tensor(state['state1'])
if 'state2' in state:
F.prefetch_tensor(state['state2'])
F.prefetch_tensor(state["state1"])
if "state2" in state:
F.prefetch_tensor(state["state2"])
class Optimizer2State(Optimizer8bit):
......@@ -369,7 +353,7 @@ class Optimizer2State(Optimizer8bit):
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
is_paged=False
is_paged=False,
):
"""
Base 2-state update optimizer class.
......@@ -414,13 +398,9 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
f"Invalid beta parameter at index {i}: {betas[i]}"
)
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged)
......@@ -449,9 +429,7 @@ class Optimizer2State(Optimizer8bit):
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
......@@ -459,21 +437,15 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
p.device
)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
p.device
)
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"]
......@@ -486,25 +458,13 @@ class Optimizer2State(Optimizer8bit):
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
state["max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
......@@ -524,7 +484,10 @@ class Optimizer2State(Optimizer8bit):
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
)
else:
gnorm_scale = 1.0
......@@ -568,9 +531,7 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"]
if config["max_unorm"] > 0.0
else None,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
......@@ -615,7 +576,7 @@ class Optimizer1State(Optimizer8bit):
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
is_paged=False
is_paged=False,
):
"""
Base 1-state update optimizer class.
......@@ -656,13 +617,9 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
f"Invalid beta parameter at index {i}: {betas[i]}"
)
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError(
f"Invalid weight_decay value: {weight_decay}"
)
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits, is_paged)
......@@ -691,9 +648,7 @@ class Optimizer1State(Optimizer8bit):
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
......@@ -701,17 +656,13 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
p.device
)
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"]
......@@ -721,16 +672,10 @@ class Optimizer1State(Optimizer8bit):
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
state["max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
......@@ -750,7 +695,10 @@ class Optimizer1State(Optimizer8bit):
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
grad,
state["gnorm_vec"],
step,
config["percentile_clipping"],
)
else:
gnorm_scale = 1.0
......@@ -766,7 +714,7 @@ class Optimizer1State(Optimizer8bit):
step,
config["lr"],
None,
config['betas'][1],
config["betas"][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......
......@@ -51,9 +51,7 @@ class RMSprop(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
......@@ -116,9 +114,7 @@ class RMSprop8bit(Optimizer1State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
......@@ -182,9 +178,7 @@ class RMSprop32bit(Optimizer1State):
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
......
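The guard clauses in the LARS and RMSprop hunks above are the only behavioural constraints in these files: LARS refuses momentum == 0, and RMSprop refuses alpha == 0 as well as centered=True. A minimal sketch, assuming these classes are exposed under bnb.optim with torch-like defaults (the full constructor signatures are not shown in these hunks):

import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(1024, 4096).cuda()

opt = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-2)  # default alpha > 0, centered=False
# bnb.optim.RMSprop8bit(model.parameters(), lr=1e-2, alpha=0.0)     -> NotImplementedError
# bnb.optim.RMSprop8bit(model.parameters(), lr=1e-2, centered=True) -> NotImplementedError
# bnb.optim.LARS8bit(model.parameters(), lr=1e-2, momentum=0)       -> NotImplementedError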
......@@ -195,9 +195,9 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
......@@ -216,9 +216,7 @@ class SwitchBackBnb(torch.autograd.Function):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
A.to(torch.float16), threshold=state.threshold
)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
......@@ -234,14 +232,14 @@ class SwitchBackBnb(torch.autograd.Function):
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
#print('A shape', A.shape)
# print('A shape', A.shape)
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
# 2. Quantize B
if state.has_fp16_weights:
#print('B shape', B.shape)
# print('B shape', B.shape)
has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
......@@ -272,12 +270,7 @@ class SwitchBackBnb(torch.autograd.Function):
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.to(A.dtype)
)
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
......@@ -320,14 +313,13 @@ class SwitchBackBnb(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
return clone_func(output.view(output_shape))
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
......@@ -342,9 +334,7 @@ class SwitchBackBnb(torch.autograd.Function):
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
......@@ -357,25 +347,24 @@ class SwitchBackBnb(torch.autograd.Function):
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
# print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception('State must contain either CBt or CB matrix for backward')
raise Exception("State must contain either CBt or CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None
def get_block_sizes(input_matrix, weight_matrix):
input_features = input_matrix.shape[-1]
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
bsz, bsz2 = 1024, 1024
for i, k in enumerate(array):
......@@ -399,7 +388,8 @@ def matmul_fp8_global(
bsz: int = -1,
bsz2: int = -1,
):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
......@@ -412,7 +402,8 @@ def matmul_fp8_mixed(
bsz: int = -1,
bsz2: int = -1,
):
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
if bsz == -1 or bsz2 == -1:
bsz, bsz2 = get_block_sizes(A, B)
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
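Both fp8 matmul wrappers take fw_code/bw_code quantization maps and fall back to get_block_sizes when bsz/bsz2 are left at -1. A hedged usage sketch mirroring the call pattern the LinearFP8 modules further down use (a CUDA build of bitsandbytes is assumed; the create_fp8_map arguments are copied from those modules):

import torch
import bitsandbytes as bnb

A = torch.randn(16, 1024, device="cuda", requires_grad=True)
B = torch.randn(1024, 4096, device="cuda", requires_grad=True)

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)  # forward-pass code map
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)  # backward-pass code map

# bsz/bsz2 default to -1, so block sizes are chosen by get_block_sizes(A, B).
out = bnb.research.matmul_fp8_mixed(A, B, fw_code=fw_code, bw_code=bw_code)
out.sum().backward()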
......@@ -422,7 +413,7 @@ def switchback_bnb(
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None
bias=None,
):
state = state or MatmulLtState()
if threshold > 0.0:
......
......@@ -28,12 +28,20 @@ class LinearFP8Mixed(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
out = bnb.research.matmul_fp8_mixed(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None:
out += self.bias
return out
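LinearFP8Mixed (and LinearFP8Global below) act as drop-in replacements for nn.Linear that route their matmul through the fp8 paths above, building their fp8 code maps inside forward. A short sketch, assuming the constructor mirrors nn.Linear(input_features, output_features, bias=True) as shown for LinearFP8Global and that the class is importable from bnb.research.nn:

import torch
import bitsandbytes as bnb

layer = bnb.research.nn.LinearFP8Mixed(1024, 4096).cuda()
x = torch.randn(16, 1024, device="cuda")
y = layer(x)    # fp8 matmul forward; fw/bw code maps are created on the device of x
print(y.shape)  # torch.Size([16, 4096])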
class LinearFP8Global(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
......@@ -54,7 +62,14 @@ class LinearFP8Global(nn.Linear):
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
out = bnb.matmul_fp8_global(
x,
self.weight.t(),
fw_code=self.fw_code,
bw_code=self.bw_code,
bsz=self.bsz,
bsz2=self.bsz2,
)
if self.bias is not None:
out += self.bias
......
......@@ -5,9 +5,10 @@ import torch
from bitsandbytes.triton.triton_utils import is_triton_available
if not is_triton_available():
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
return None
else:
import triton
import triton.language as tl
......@@ -29,7 +30,7 @@ else:
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
key=["n_elements"],
)
@triton.jit
def _dequantize_rowwise(
......@@ -51,7 +52,6 @@ else:
output = max_val * x * inv_127
tl.store(output_ptr + offsets, output, mask=row_mask)
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
......@@ -60,5 +60,5 @@ else:
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
_dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output
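dequantize_rowwise expects an int8 tensor plus a per-row absmax vector and rebuilds fp16 values as max_val * x * (1/127), which is exactly the arithmetic in the Triton kernel above. A plain-PyTorch reference of the same math, useful for checking results without a Triton-capable GPU (the round-trip quantization here is hand-rolled for illustration, not a library call):

import torch

def quantize_rowwise_reference(x: torch.Tensor):
    # Per-row absmax scaling to int8, mirroring what the kernel undoes.
    state_x = x.abs().amax(dim=1)
    x_int8 = torch.round(x / state_x.unsqueeze(1) * 127).to(torch.int8)
    return x_int8, state_x

def dequantize_rowwise_reference(x_int8: torch.Tensor, state_x: torch.Tensor):
    # output = max_val * x * (1/127), row by row, as in _dequantize_rowwise.
    return (x_int8.to(torch.float32) * state_x.unsqueeze(1) / 127.0).to(torch.float16)

x = torch.randn(4, 256)
x_int8, state_x = quantize_rowwise_reference(x)
x_hat = dequantize_rowwise_reference(x_int8, state_x)
assert torch.allclose(x_hat.float(), x, atol=state_x.max().item() / 127)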