Commit a24aae30 authored by Jeongseok Kang

Merge branch 'main' into fix/libcuda-to-torch

parents 2b4cc256 4395d68c
...@@ -221,3 +221,29 @@ Improvements:
Deprecated:
- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
### 0.38.1
Features:
- Added Int8 SwitchBack layers
- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)
### 0.39.0
Features:
- 4-bit matrix multiplication for Float4 and NormalFloat4 data types.
- Added 4-bit quantization routines
- Double quantization routines for 4-bit quantization
- Paged optimizers for Adam and Lion.
- bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states.
Bug fixes:
- Fixed a bug where 8-bit models consumed twice the memory as expected after serialization
Deprecated:
- Kepler binaries (GTX 700s and Tesla K40/K80) are no longer provided via pip and need to be compiled from source. Kepler support might be fully removed in the future.
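For the paged optimizers listed above, usage is a drop-in swap of the optimizer class. A minimal sketch, assuming a CUDA-enabled build of this changeset; the toy model, sizes, and hyper-parameters are placeholders:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# Paged variants keep their optimizer state in paged (unified) memory, so it can be
# evicted to CPU under GPU memory pressure; the 8-bit variant also quantizes the state.
optimizer = bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.999))

out = model(torch.randn(16, 1024, device="cuda"))
out.sum().backward()
optimizer.step()
optimizer.zero_grad()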
...@@ -2,6 +2,7 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) ...@@ -2,6 +2,7 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH)))
GPP:= /usr/bin/g++ GPP:= /usr/bin/g++
#GPP:= /sw/gcc/11.2.0/bin/g++
ifeq ($(CUDA_HOME),) ifeq ($(CUDA_HOME),)
CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif endif
...@@ -12,6 +13,7 @@ CUDA_VERSION:= ...@@ -12,6 +13,7 @@ CUDA_VERSION:=
endif endif
NVCC := $(CUDA_HOME)/bin/nvcc NVCC := $(CUDA_HOME)/bin/nvcc
########################################### ###########################################
...@@ -23,8 +25,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu ...@@ -23,8 +25,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
# NVIDIA NVCC compilation flags # NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
...@@ -32,17 +33,11 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell ...@@ -32,17 +33,11 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler
CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler
# Later versions of CUDA support the new architectures # Later versions of CUDA support the new architectures
CC_CUDA10x += -gencode arch=compute_75,code=sm_75
CC_CUDA110 := -gencode arch=compute_75,code=sm_75
CC_CUDA110 += -gencode arch=compute_80,code=sm_80
CC_CUDA11x := -gencode arch=compute_75,code=sm_75 CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86 CC_CUDA11x += -gencode arch=compute_86,code=sm_86
...@@ -59,29 +54,30 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 ...@@ -59,29 +54,30 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env all: $(BUILD_DIR) env
$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env cuda110_nomatmul_kepler: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env cuda11x_nomatmul_kepler: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda110_nomatmul: $(BUILD_DIR) env cuda110_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda11x_nomatmul: $(BUILD_DIR) env cuda11x_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda12x_nomatmul: $(BUILD_DIR) env cuda12x_nomatmul: $(BUILD_DIR) env
......
Steps:
1. Run `python speed_benchmark/speed_benchmark.py`, which times the individual operations and appends their timings to `speed_benchmark/info_a100_py2.jsonl` (change the jsonl name for your own profiling run).
2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces `speed_benchmark/plot_with_info.pdf`. Again, make sure you change the jsonl file that is being processed (see the sketch after these steps).
\ No newline at end of file
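The jsonl produced in step 1 holds one JSON object per (dim, batch size, switch) configuration, with the timings recorded by `get_time` plus the derived totals. A quick sanity check before plotting; a sketch that assumes the default A100 file name mentioned above:

import pandas as pd

# each line is one benchmark configuration; these columns are written by speed_benchmark.py
df = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
print(df[['dim_in', 'dim_out', 'batch_size', 'time_standard', 'time_rowwise', 'time_global']].head())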
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
gs = gridspec.GridSpec(1, 2)
dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
batch_size_for_plot1 = 32768
batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
dims_to_xtick = [1024, 2048, 4096]
logscale_plot1 = True
ax = fig.add_subplot(gs[0, 0])
# TODO: change this to what you want.
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
df = rdf[rdf.batch_size == batch_size_for_plot1]
# first plot the time occupied by different operations
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
]:
xs = []
ys = []
for embed_dim in dims_to_consider:
# average over dim -> 4*dim and 4*dim -> dim
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13)
ax.grid()
ax.set_xscale('log')
if logscale_plot1:
ax.set_yscale('log')
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight('bold')
leg.get_texts()[1].set_fontweight('bold')
plt.subplots_adjust(left=0.1)
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
ax = fig.add_subplot(gs[0, 1])
# now plot the % speedup for different batch sizes
for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in dims_to_consider:
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
all_ys.append(ys)
color = cmap(j * 0.25)
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=13)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
import json
import time
import torch
import torch.nn as nn
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when the embedding dimension is too large.
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
if __name__ == '__main__':
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
# TODO: change this to what you want.
with open("speed_benchmark/info.jsonl", "a") as file:
file.write(info_json + "\n")
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from . import cuda_setup, utils from . import cuda_setup, utils, research
from .autograd._functions import ( from .autograd._functions import (
MatmulLtState, MatmulLtState,
bmm_cublas, bmm_cublas,
matmul, matmul,
matmul_cublas, matmul_cublas,
mm_cublas, mm_cublas,
matmul_4bit
) )
from .cextension import COMPILED_WITH_CUDA from .cextension import COMPILED_WITH_CUDA
from .nn import modules from .nn import modules
......
...@@ -2,7 +2,7 @@ import operator ...@@ -2,7 +2,7 @@ import operator
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce # Required in Python 3 from functools import reduce # Required in Python 3
from typing import Tuple, Optional from typing import Tuple, Optional, List
import torch import torch
...@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool: ...@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
return True return True
def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
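For reference, the factored-out helpers can be exercised on their own; a minimal round-trip sketch, assuming a CUDA device and the col_turing layout (the int8 weight here is random illustrative data):

import torch
import bitsandbytes.functional as F
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout

# inverse permutation from the GPU-specific tile layout back to row-major order
tile_indices = get_tile_inds("col_turing", torch.device("cuda:0"))

W_int8 = torch.randint(-128, 127, (4096, 1024), dtype=torch.int8, device="cuda")
CxB, _ = F.transform(W_int8, to_order="col_turing")   # row-major -> col_turing layout
W_restored = undo_layout(CxB, tile_indices)           # back to row-major
assert torch.equal(W_restored, W_int8)                # expected to hold for tile-aligned shapes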
@dataclass @dataclass
class MatmulLtState: class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None _tile_indices: Optional[torch.Tensor] = None
...@@ -267,20 +280,10 @@ class MatmulLtState: ...@@ -267,20 +280,10 @@ class MatmulLtState:
self.SBt = None self.SBt = None
self.CBt = None self.CBt = None
def get_tile_size(self):
assert self.formatB in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property @property
def tile_indices(self): def tile_indices(self):
if self._tile_indices is None: if self._tile_indices is None:
device = self.CxB.device self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
return self._tile_indices return self._tile_indices
...@@ -424,10 +427,10 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -424,10 +427,10 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]): if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA) ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx) ctx.tensor_states = (SCAt, state.idx)
else: else:
ctx.tensors = [None, None] ctx.tensors = [None, None, A]
ctx.tensor_states = (None, None) ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None) ctx.save_for_backward(None, None)
...@@ -440,7 +443,7 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -440,7 +443,7 @@ class MatMul8bitLt(torch.autograd.Function):
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA = ctx.tensors CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states SCAt, idx = ctx.tensor_states
formatB = ctx.formatB formatB = ctx.formatB
state = ctx.state state = ctx.state
...@@ -487,6 +490,64 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -487,6 +490,64 @@ class MatMul8bitLt(torch.autograd.Function):
return grad_A, grad_B, None, grad_bias, None return grad_A, grad_B, None, grad_bias, None
class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.bias = bias
B_shape = state[1]
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. Matmul
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
# 3. Save state
ctx.state = state
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, B)
else:
ctx.tensors = (None, None)
return output
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
A, B = ctx.tensors
state = ctx.state
grad_A, grad_B, grad_bias = None, None, None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
return grad_A, grad_B, None, grad_bias, None
def matmul( def matmul(
A: tensor, A: tensor,
B: tensor, B: tensor,
...@@ -499,3 +560,8 @@ def matmul( ...@@ -499,3 +560,8 @@ def matmul(
if threshold > 0.0: if threshold > 0.0:
state.threshold = threshold state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
assert quant_state is not None
return MatMul4Bit.apply(A, B, out, bias, quant_state)
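A functional-level sketch of the new 4-bit path exposed by matmul_4bit; it mirrors the call pattern used by Linear4bit.forward further down (the shapes, blocksize, and fp4 quant type are illustrative):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")      # activations
W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")   # fp16 weight (out_features, in_features)

# pack the weight into 4-bit blocks; quant_state carries absmax, shape, dtype, blocksize, ...
W_4bit, quant_state = F.quantize_4bit(W, blocksize=64, quant_type='fp4')

# same pattern as Linear4bit.forward: MatMul4Bit dequantizes W and runs an fp16 matmul
out = bnb.matmul_4bit(A, W_4bit.t(), quant_state=quant_state)     # -> (8, 4096), fp16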
...@@ -18,17 +18,24 @@ try: ...@@ -18,17 +18,24 @@ try:
CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().generate_instructions()
CUDASetup.get_instance().print_log_stack() CUDASetup.get_instance().print_log_stack()
raise RuntimeError(''' raise RuntimeError('''
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment! CUDA Setup failed despite GPU being available. Please run the following command to get more information:
If you cannot find any issues and suspect a bug, please open an issue with details about your environment:
https://github.com/TimDettmers/bitsandbytes/issues''') python -m bitsandbytes
lib.cadam32bit_g32
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_grad_fp32 # raises an AttributeError if the library could not be found -> COMPILED_WITH_CUDA=False
lib.get_context.restype = ct.c_void_p lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p
lib.cget_managed_ptr.restype = ct.c_void_p
COMPILED_WITH_CUDA = True COMPILED_WITH_CUDA = True
except AttributeError: except AttributeError as ex:
warn("The installed version of bitsandbytes was compiled without GPU support. " warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.") "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False COMPILED_WITH_CUDA = False
print(str(ex))
# print the setup details after checking for errors so we do not print twice # print the setup details after checking for errors so we do not print twice
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
......
...@@ -44,6 +44,9 @@ class CUDASetup: ...@@ -44,6 +44,9 @@ class CUDASetup:
raise RuntimeError("Call get_instance() instead") raise RuntimeError("Call get_instance() instead")
def generate_instructions(self): def generate_instructions(self):
if getattr(self, 'error', False): return
print(self.error)
self.error = True
if self.cuda is None: if self.cuda is None:
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
...@@ -93,6 +96,7 @@ class CUDASetup: ...@@ -93,6 +96,7 @@ class CUDASetup:
self.has_printed = False self.has_printed = False
self.lib = None self.lib = None
self.initialized = False self.initialized = False
self.error = False
def run_cuda_setup(self): def run_cuda_setup(self):
self.initialized = True self.initialized = True
......
...@@ -2,4 +2,5 @@ ...@@ -2,4 +2,5 @@
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
...@@ -10,8 +10,9 @@ from torch import Tensor, device, dtype, nn ...@@ -10,8 +10,9 @@ from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
T = TypeVar("T", bound="torch.nn.Module") T = TypeVar("T", bound="torch.nn.Module")
...@@ -135,6 +136,101 @@ class Embedding(torch.nn.Embedding): ...@@ -135,6 +136,101 @@ class Embedding(torch.nn.Embedding):
return emb return emb
class Params4bit(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
if data is None:
data = torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.blocksize = blocksize
self.compress_statistics = compress_statistics
self.quant_type = quant_type
self.quant_state = quant_state
self.data = data
return self
def cuda(self, device):
w = self.data.contiguous().half().cuda(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
self.data = w_4bit
self.quant_state = quant_state
return self
@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
...
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
...
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
return self.cuda(device)
else:
s = self.quant_state
if s is not None:
# make sure the quantization state is on the right device
s[0] = s[0].to(device)
if self.compress_statistics:
# TODO: refactor this. This is a nightmare
# for 4-bit:
# state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
# state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
#s[-2][0] = s[-2][0].to(device) # offset
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
# for 8-bit
s[-2][0] = s[-2][0].to(device) # offset
s[-2][1][0] = s[-2][1][0].to(device) # nested quantization state statistics
s[-2][1][1] = s[-2][1][1].to(device) # nested quantization codebook
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad, quant_state=self.quant_state,
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type)
return new_param
class Linear4bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
super().__init__(input_features, output_features, bias)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
self.compute_dtype = compute_dtype
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, 'quant_state', None) is None:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
out = out.to(inp_dtype)
return out
class LinearFP4(Linear4bit):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
class LinearNF4(Linear4bit):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
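Typical use of the new 4-bit layers: build the module on CPU, then move it to the GPU, which is when Params4bit.cuda() quantizes and packs the weight. A small sketch, assuming a CUDA build; the sizes and compute dtype are placeholders:

import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearNF4(1024, 4096, compute_dtype=torch.float16)
layer = layer.cuda()                      # Params4bit.cuda() packs the weight to 4-bit here

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = layer(x)                          # forward dequantizes on the fly via matmul_4bit

LinearFP4 behaves the same way, differing only in the underlying 4-bit data type.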
class Int8Params(torch.nn.Parameter): class Int8Params(torch.nn.Parameter):
def __new__( def __new__(
...@@ -210,6 +306,18 @@ class Int8Params(torch.nn.Parameter): ...@@ -210,6 +306,18 @@ class Int8Params(torch.nn.Parameter):
return new_param return new_param
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight = state_dict.get(f"{prefix}weight")
if weight is None:
# if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
return
weight_format = state_dict.pop(f"{prefix}weight_format", "row")
if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
class Linear8bitLt(nn.Linear): class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None): memory_efficient_backward=False, threshold=0.0, index=None):
...@@ -225,52 +333,55 @@ class Linear8bitLt(nn.Linear): ...@@ -225,52 +333,55 @@ class Linear8bitLt(nn.Linear):
self.state.use_pool = True self.state.use_pool = True
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
def _save_to_state_dict(self, destination, prefix, keep_vars): def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
# reorder weight layout back from ampere/turing to row
reorder_layout = True
weight_clone = self.weight.data.clone()
else:
reorder_layout = False
try:
if reorder_layout:
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
super()._save_to_state_dict(destination, prefix, keep_vars) super()._save_to_state_dict(destination, prefix, keep_vars)
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
weight_name = "SCB" scb_name = "SCB"
# case 1: .cuda was called, SCB is in self.weight # case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, weight_name) param_from_weight = getattr(self.weight, scb_name)
# case 2: self.init_8bit_state was called, SCB is in self.state # case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, weight_name) param_from_state = getattr(self.state, scb_name)
# case 3: SCB is in self.state, weight layout reordered after first forward()
layout_reordered = self.state.CxB is not None
key_name = prefix + f"{scb_name}"
format_name = prefix + "weight_format"
key_name = prefix + f"{weight_name}" if not self.state.has_fp16_weights:
if param_from_weight is not None: if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
elif not self.state.has_fp16_weights and param_from_state is not None: destination[format_name] = "row"
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = "row"
elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
finally: destination[format_name] = self.state.formatB
if reorder_layout:
self.weight.data = weight_clone
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs): missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs) error_msgs)
for key in unexpected_keys: unexpected_copy = list(unexpected_keys)
for key in unexpected_copy:
input_name = key[len(prefix):] input_name = key[len(prefix):]
if input_name == "SCB": if input_name == "SCB":
if self.weight.SCB is None: if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without # buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()") "not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key] input_param = state_dict[key]
self.weight.SCB.copy_(input_param) self.weight.SCB.copy_(input_param)
if self.state.SCB is not None:
self.state.SCB = self.weight.SCB
unexpected_keys.remove(key) unexpected_keys.remove(key)
def init_8bit_state(self): def init_8bit_state(self):
...@@ -289,6 +400,7 @@ class Linear8bitLt(nn.Linear): ...@@ -289,6 +400,7 @@ class Linear8bitLt(nn.Linear):
self.bias.data = self.bias.data.to(x.dtype) self.bias.data = self.bias.data.to(x.dtype)
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights: if not self.state.has_fp16_weights:
if self.state.CB is not None and self.state.CxB is not None: if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass # we converted 8-bit row major to turing/ampere format in the first inference pass
...@@ -296,3 +408,71 @@ class Linear8bitLt(nn.Linear): ...@@ -296,3 +408,71 @@ class Linear8bitLt(nn.Linear):
del self.state.CB del self.state.CB
self.weight.data = self.state.CxB self.weight.data = self.state.CxB
return out return out
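With the weight_format entry written by _save_to_state_dict and the maybe_rearrange_weight pre-hook above, an int8 checkpoint can be saved even after the first forward pass has reordered the weight, and loaded back into a fresh module. A rough round-trip sketch; the sizes and file name are placeholders and a CUDA build is assumed:

import torch
import bitsandbytes as bnb

lin = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0).cuda()
_ = lin(torch.randn(4, 1024, dtype=torch.float16, device="cuda"))   # first forward reorders to turing/ampere

sd = lin.state_dict()                     # now contains "weight", "SCB" and "weight_format"
torch.save(sd, "int8_linear.pt")

fresh = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0).cuda()
fresh.load_state_dict(torch.load("int8_linear.pt"))                 # pre-hook restores row-major layout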
class OutlierAwareLinear(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.outlier_dim = None
self.is_quantized = False
def forward_with_outliers(self, x, outlier_idx):
raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
def quantize_weight(self, w, outlier_idx):
raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function')
def forward(self, x):
if self.outlier_dim is None:
tracer = OutlierTracer.get_instance()
if not tracer.is_initialized():
print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
outlier_idx = tracer.get_outliers(self.weight)
#print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx
if not self.is_quantized:
w = self.quantize_weight(self.weight, self.outlier_dim)
self.weight.data.copy_(w)
self.is_quantized = True
class SwitchBackLinearBnb(nn.Linear):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()
self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
def forward(self, x):
self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
import torch
import torch.nn as nn
import time
from functools import partial
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
class _switchback_global(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X, W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
X, W = ctx.save_for_backward
if ctx.needs_input_grad[0]:
# rowwise quantize for G, global quantize for W
# for W, we also fuse the transpose operation because only A @ B^T is supported
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
ctx.save_for_backward = X, W
# rowwise quantize for X
# columnwise quantize for W (first rowwise, transpose later)
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_rowwise(W)
# matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
X, W = ctx.save_for_backward
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
if ctx.needs_input_grad[0]:
# rowwise quantize for G, columnwise quantize for W and fused transpose
# we call .t() for weight later because only A @ B^T is supported
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
X_3D_sz = X_3D.size()
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
del X
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X_int8, state_X, W_int8, state_W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
G_3D_sz = G_3D.size()
grad_X = grad_W = grad_bias = None
X_int8, state_X, W_int8, state_W = ctx.save_for_backward
if ctx.needs_input_grad[1]:
real_X = dequantize_rowwise(X_int8, state_X)
del X_int8
grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
del real_X
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
if ctx.needs_input_grad[0]:
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)
return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
vector_wise_quantization: bool = False,
mem_efficient : bool = False,
):
super().__init__(in_features, out_features, bias, device, dtype)
if not is_triton_available:
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
# By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization:
self._fn = _switchback_vectorrize
if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.')
exit(1)
else:
if mem_efficient:
self._fn = _switchback_global_mem_efficient
else:
self._fn = _switchback_global
def prepare_for_eval(self):
# If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
# Note this is experimental and not tested thoroughly.
# Note this needs to be explicitly called with something like
# def cond_prepare(m):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print('=> preparing for eval.')
if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight)
else:
W_int8, state_W = quantize_global(self.weight)
self.register_buffer("W_int8", W_int8)
self.register_buffer("state_W", state_W)
del self.weight
def forward(self, x):
if self.training:
return self._fn.apply(x, self.weight, self.bias)
else:
# If it hasn't been "prepared for eval", run the standard forward pass.
if not hasattr(self, "W_int8"):
return self._fn.apply(x, self.weight, self.bias)
# Otherwise, use pre-computed weights.
X = x.view(-1, x.size(-1))
X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias=None):
X = input.view(-1, input.size(-1))
ctx.save_for_backward(X, weight, bias)
output = input.matmul(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output.view(*input.size()[:-1], -1)
@staticmethod
def backward(ctx, grad_output_3D):
input, weight, bias = ctx.saved_tensors
grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias)
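Putting the autograd functions above together, SwitchBackLinear is intended as a drop-in replacement for nn.Linear during training; a small sketch, assuming CUDA and a working triton install (the shapes are illustrative):

import torch
import bitsandbytes as bnb

layer = bnb.nn.SwitchBackLinear(1024, 4096).cuda().half()   # global quantization for W by default

x = torch.randn(8, 256, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
y = layer(x)                  # int8 matmul + fused dequantize in the forward
y.float().mean().backward()   # int8 matmuls for grad_X, fp16 matmul for grad_W

# for inference, the weight can be pre-quantized once:
layer.eval()
layer.prepare_for_eval()
with torch.no_grad():
    y = layer(x)

The partial() aliases at the end of the file (SwitchBackLinearGlobal, SwitchBackLinearVectorwise) simply pre-set vector_wise_quantization.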
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA from bitsandbytes.cextension import COMPILED_WITH_CUDA
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
from .lamb import LAMB, LAMB8bit, LAMB32bit from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .lion import Lion, Lion8bit, Lion32bit from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
from .sgd import SGD, SGD8bit, SGD32bit from .sgd import SGD, SGD8bit, SGD32bit
...@@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State ...@@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State): class Adam(Optimizer2State):
def __init__( def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
self, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
params, super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Adam8bit(Optimizer2State): class Adam8bit(Optimizer2State):
def __init__( def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
self, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
params, super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Adam32bit(Optimizer2State): class Adam32bit(Optimizer2State):
def __init__( def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
self, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
params, super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
lr=1e-3,
betas=(0.9, 0.999), class PagedAdam(Optimizer2State):
eps=1e-8, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
amsgrad=False, super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
args=None,
min_8bit_size=4096, class PagedAdam8bit(Optimizer2State):
percentile_clipping=100, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
block_wise=True, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
): super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"adam", class PagedAdam32bit(Optimizer2State):
params, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
lr, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
betas, super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class AnalysisAdam(torch.optim.Optimizer): class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis. """Adam that performs 8-bit vs 32-bit error analysis.
......
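As the rewritten signatures above show, the new Paged* classes are thin wrappers that force is_paged=True; the two spellings below should be equivalent (a sketch, the parameter tensor and learning rate are placeholders):

import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(1024, 1024, device="cuda"))]

opt_a = bnb.optim.Adam8bit(params, lr=1e-3, is_paged=True)   # explicit flag on the existing class
opt_b = bnb.optim.PagedAdam8bit(params, lr=1e-3)             # paged by construction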
@@ -5,89 +5,35 @@
 from bitsandbytes.optim.optimizer import Optimizer2State
 
-class AdamW(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
 
 class AdamW8bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
 
 class AdamW32bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+class PagedAdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
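The AdamW changes mirror the Adam ones: every class gains `is_paged`, and the `PagedAdamW*` variants are thin wrappers that only force `is_paged=True`. A small sketch of that equivalence, assuming a CUDA build of bitsandbytes 0.39.0+:

```python
import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(256, 256, device="cuda"))]

a = bnb.optim.AdamW8bit(params, lr=2e-4, is_paged=True)
b = bnb.optim.PagedAdamW8bit(params, lr=2e-4)

# Both routes end up in Optimizer2State with is_paged=True.
assert a.is_paged and b.is_paged
```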
@@ -4,84 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer1State
 
 class Lion(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
 
 class Lion8bit(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
 
 class Lion32bit(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+class PagedLion(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
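Lion derives from `Optimizer1State`, so it tracks a single momentum buffer where Adam tracks two, and its 8-bit and paged variants therefore hold roughly half the optimizer state. A rough sketch comparing the lazily created state after one step, assuming a CUDA build of bitsandbytes (shapes and gradients are arbitrary):

```python
import torch
import bitsandbytes as bnb

p_adam = torch.nn.Parameter(torch.randn(512, 512, device="cuda"))
p_lion = torch.nn.Parameter(torch.randn(512, 512, device="cuda"))

adam = bnb.optim.Adam8bit([p_adam])
lion = bnb.optim.Lion8bit([p_lion])

for p, opt in ((p_adam, adam), (p_lion, lion)):
    p.grad = torch.randn_like(p)
    opt.step()

# Adam8bit keeps two quantized states (state1/state2 plus their qmaps);
# Lion8bit keeps only state1.
print(sorted(adam.state[p_adam].keys()))
print(sorted(lion.state[p_lion].keys()))
```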
@@ -92,10 +92,12 @@ class GlobalOptimManager:
 
 class Optimizer8bit(torch.optim.Optimizer):
-    def __init__(self, params, defaults, optim_bits=32):
+    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
         super().__init__(params, defaults)
         self.initialized = False
         self.name2qmap = {}
+        self.is_paged = is_paged
+        self.page_mng = F.GlobalPageManager.get_instance()
 
         self.mng = GlobalOptimManager.get_instance()
         self.non_castable_tensor_keys = {
@@ -207,6 +209,8 @@ class Optimizer8bit(torch.optim.Optimizer):
                 values = self.state[p]
                 for k, v in values.items():
                     if isinstance(v, torch.Tensor):
-                        self.state[p][k] = v.to(p.device)
+                        is_paged = getattr(v, 'is_paged', False)
+                        if not is_paged:
+                            self.state[p][k] = v.to(p.device)
 
     def check_overrides(self):
@@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             self.to_gpu()  # needed for fairseq pure fp16 training
             self.initialized = True
 
+        #if self.is_paged: self.page_mng.prefetch_all()
         for gindex, group in enumerate(self.param_groups):
             for pindex, p in enumerate(group["params"]):
                 if p.grad is None:
@@ -260,7 +265,14 @@ class Optimizer8bit(torch.optim.Optimizer):
                 if len(state) == 0:
                     self.init_state(group, p, gindex, pindex)
 
+                self.prefetch_state(p)
                 self.update_step(group, p, gindex, pindex)
+                torch.cuda.synchronize()
+
+        if self.is_paged:
+            # all paged operations are asynchronous; we need to sync
+            # to make sure all tensors are in the right state
+            torch.cuda.synchronize()
 
         return loss
@@ -289,6 +301,26 @@ class Optimizer8bit(torch.optim.Optimizer):
             "The update_step method needs to be overridden"
         )
 
+    def get_state_buffer(self, p, dtype=torch.float32):
+        if not self.is_paged or p.numel() < 1e5:
+            return torch.zeros_like(p, dtype=dtype, device=p.device)
+        else:
+            # large state tensor (>= 1e5 elements): allocate it as paged memory
+            buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
+            F.fill(buff, 0)
+            self.page_mng.paged_tensors.append(buff)
+            return buff
+
+    def prefetch_state(self, p):
+        if self.is_paged:
+            state = self.state[p]
+            s1 = state['state1']
+            is_paged = getattr(s1, 'is_paged', False)
+            if is_paged:
+                F.prefetch_tensor(state['state1'])
+                if 'state2' in state:
+                    F.prefetch_tensor(state['state2'])
+
 
 class Optimizer2State(Optimizer8bit):
     def __init__(
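`get_state_buffer` above is the one place where paged allocation happens: when `is_paged` is off or the state tensor is small, it returns an ordinary CUDA tensor; otherwise it takes a buffer from `F.get_paged`, zero-fills it, and registers it with the page manager so that `prefetch_state` can move it onto the GPU ahead of `update_step`. A standalone sketch of that decision using the same `bitsandbytes.functional` helpers; `make_state_buffer` is an illustrative name, not library API, and a CUDA build with paged support is assumed:

```python
import torch
import bitsandbytes.functional as F

def make_state_buffer(p, dtype=torch.float32, is_paged=True):
    # Small tensors stay in regular device memory; large ones get paged
    # (unified) memory that the driver can evict to host RAM under pressure.
    if not is_paged or p.numel() < 1e5:
        return torch.zeros_like(p, dtype=dtype, device=p.device)
    buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
    F.fill(buff, 0)
    return buff

p = torch.nn.Parameter(torch.randn(1024, 1024, device="cuda"))
state = make_state_buffer(p)              # 1M elements -> paged buffer
print(getattr(state, "is_paged", False))  # True: to_gpu()/prefetch_state key off this flag
F.prefetch_tensor(state)                  # asynchronously page it onto the GPU
torch.cuda.synchronize()
```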
@@ -306,6 +338,7 @@ class Optimizer2State(Optimizer8bit):
         block_wise=True,
         max_unorm=0.0,
         skip_zeros=False,
+        is_paged=False
     ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -325,7 +358,7 @@ class Optimizer2State(Optimizer8bit):
                 f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits, is_paged)
 
         if args is None:
             args = {}
@@ -365,18 +398,8 @@ class Optimizer2State(Optimizer8bit):
         if dtype == torch.float32 or (
            dtype == torch.uint8 and p.numel() < 4096
         ):
-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.float32,
-                device=p.device,
-            )
-            state["state2"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.float32,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
+            state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
@@ -388,20 +411,10 @@ class Optimizer2State(Optimizer8bit):
                     p.device
                 )
 
-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.uint8,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap1"] = self.name2qmap["dynamic"]
 
-            state["state2"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.uint8,
-                device=p.device,
-            )
+            state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap2"] = self.name2qmap["udynamic"]
 
         if config["block_wise"]:
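Note the guard that survives the rewrite: even under an 8-bit optimizer, parameters with fewer than 4096 elements keep plain fp32 state, so tiny tensors are never quantized. A quick sketch of the effect, assuming a CUDA build of bitsandbytes:

```python
import torch
import bitsandbytes as bnb

small = torch.nn.Parameter(torch.randn(1000, device="cuda"))      # < 4096 elements
large = torch.nn.Parameter(torch.randn(64, 1024, device="cuda"))  # 65536 elements
opt = bnb.optim.Adam8bit([small, large])

for p in (small, large):
    p.grad = torch.zeros_like(p)
opt.step()

print(opt.state[small]["state1"].dtype)  # torch.float32 -- below the 4096-element cutoff
print(opt.state[large]["state1"].dtype)  # torch.uint8   -- quantized 8-bit state
```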
@@ -538,6 +551,7 @@ class Optimizer1State(Optimizer8bit):
         block_wise=True,
         max_unorm=0.0,
         skip_zeros=False,
+        is_paged=False
     ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -553,7 +567,7 @@ class Optimizer1State(Optimizer8bit):
                 f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits, is_paged)
 
         if args is None:
             args = {}
@@ -593,12 +607,7 @@ class Optimizer1State(Optimizer8bit):
         if dtype == torch.float32 or (
            dtype == torch.uint8 and p.numel() < 4096
         ):
-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.float32,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
@@ -607,12 +616,7 @@ class Optimizer1State(Optimizer8bit):
                     p.device
                 )
 
-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.uint8,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap1"] = self.name2qmap["dynamic"]
 
         if config["block_wise"]:
...