Commit 675baa79 authored by Tim Dettmers's avatar Tim Dettmers

Merge remote-tracking branch 'origin/main' into merge

parents f64cfe65 9e7cdc9e
...@@ -201,3 +201,30 @@ Features:
Improvements:
- Improved logging for the CUDA detection mechanism.
### 0.38.0
#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub
Features:
- Support for 32 and 8-bit Lion has been added. Thank you @lucidrains
- Support for serialization of Linear8bitLt layers (LLM.int8()). This allows storing and loading 8-bit weights directly from the HuggingFace Hub (see the sketch below). Thank you @myrab
- New bug report feature: `python -m bitsandbytes` now gives extensive debugging details to help debug CUDA setup failures.
Bug fixes:
- Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins
- Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases.
Improvements:
- Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries
Deprecated:
- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
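As a rough illustration of the new serialization support, the following sketch shows how 8-bit weights can be saved through the HuggingFace integration (the model name and output paths are placeholders, not part of this release):
```python
# Hedged sketch: assumes a transformers version with the bitsandbytes integration enabled.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",   # placeholder model
    device_map="auto",
    load_in_8bit=True,     # quantize linear layers with LLM.int8()
)
model.save_pretrained("opt-350m-8bit")        # stores the 8-bit weights together with their SCB scales
# model.push_to_hub("my-user/opt-350m-8bit")  # or push directly to the Hub (placeholder repo name)
```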
### 0.38.1
Features:
- Added Int8 SwitchBack layers
- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)
...@@ -11,11 +11,41 @@ Resources:
## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
(Deprecated: CUDA 10.0 is deprecated; only CUDA >= 11.0 will be supported with release 0.39.0.)
**Installation**:
``pip install bitsandbytes``
In some cases you may need to compile from source. If this happens, please consider submitting a bug report with the `python -m bitsandbytes` information. The short instructions below might work out of the box if `nvcc` is installed. If they do not, see the more detailed instructions further below.
Compilation quickstart:
```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes
# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```
**Using Int8 inference with HuggingFace Transformers**
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
  'decapoda-research/llama-7b-hf',
  device_map='auto',
  load_in_8bit=True,
  max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB')
```
A more detailed example can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py).
**Using 8-bit optimizer**:
1. Comment out optimizer: ``#torch.optim.Adam(....)``
2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same), as sketched below.
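For instance, a minimal sketch of this swap (the model and hyperparameters here are placeholders):
```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1. comment out the 32-bit optimizer
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)   # 2. drop-in 8-bit optimizer, same arguments
```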
...@@ -40,7 +70,7 @@ out = linear(x.to(torch.float16))
## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization, and normalization
- 8-bit quantization: Quantile, Linear, and Dynamic quantization (see the sketch below)
- Fast quantile estimation: Up to 100x faster than other algorithms
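As a small illustration of the blockwise quantization primitives (a sketch assuming a CUDA device; the functions live in `bitsandbytes.functional`):
```python
import torch
import bitsandbytes.functional as F

x = torch.randn(4096, device="cuda")
q, quant_state = F.quantize_blockwise(x)         # uint8 codes plus per-block absmax state
x_deq = F.dequantize_blockwise(q, quant_state)   # approximate reconstruction of x
print((x - x_deq).abs().max())                   # the quantization error is small
```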
...@@ -113,8 +143,23 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)
## Compile from source
To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands.
```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc
bash cuda_install.sh 118 ~/local 1
```
To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`. For example, the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the CUDA version at `~/local/cuda-11.7`:
``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``
For more detailed instructions, please see [compile_from_source.md](compile_from_source.md).
## License
......
Steps:
1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling).
2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed.
\ No newline at end of file
This diff is collapsed.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
gs = gridspec.GridSpec(1, 2)
dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
batch_size_for_plot1 = 32768
batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
dims_to_xtick = [1024, 2048, 4096]
logscale_plot1 = True
ax = fig.add_subplot(gs[0, 0])
# TODO: change this to what you want.
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
df = rdf[rdf.batch_size == batch_size_for_plot1]
# first plot the time occupied by different operations
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
]:
xs = []
ys = []
for embed_dim in dims_to_consider:
# average over dim -> 4*dim and 4*dim -> dim
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13)
ax.grid()
ax.set_xscale('log')
if logscale_plot1:
ax.set_yscale('log')
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight('bold')
leg.get_texts()[1].set_fontweight('bold')
plt.subplots_adjust(left=0.1)
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
ax = fig.add_subplot(gs[0, 1])
# now plot the % speedup for different batch sizes
for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in dims_to_consider:
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
all_ys.append(ys)
color = cmap(j * 0.25)
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=13)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
import json
import time
import torch
import torch.nn as nn
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embed_dim is too large.
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
if __name__ == '__main__':
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
# TODO: change this to what you want.
with open("speed_benchmark/info.jsonl", "a") as file:
file.write(info_json + "\n")
...@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import cuda_setup, utils, research
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
......
import os
import sys
import shlex
import subprocess
from warnings import warn
from typing import Tuple
from os.path import isdir
import torch
HEADER_WIDTH = 60
def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
return tuple(
to_decode.decode("UTF-8").strip()
for to_decode in subprocess_err_out_tuple
)
def execute_and_return_decoded_std_streams(command_string):
return _decode(
subprocess.Popen(
shlex.split(command_string),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
).communicate()
)
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err
def find_file_recursive(folder, filename):
cmd = f'find {folder} -name {filename}'
out, err = execute_and_return(cmd)
if len(err) > 0:
raise RuntimeError('Something went wrong when trying to find the file. Maybe you do not have a Linux system?')
return out
def generate_bug_report_information():
print_header("")
print_header("BUG REPORT INFORMATION")
print_header("")
print('')
if 'CONDA_PREFIX' in os.environ:
paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so')
print_header("ANACONDA CUDA PATHS")
print(paths)
print('')
if isdir('/usr/local/'):
paths = find_file_recursive('/usr/local', '*cuda*so')
print_header("/usr/local CUDA PATHS")
print(paths)
print('')
if isdir(os.getcwd()):
paths = find_file_recursive(os.getcwd(), '*cuda*so')
print_header("WORKING DIRECTORY CUDA PATHS")
print(paths)
print('')
print_header("LD_LIBRARY CUDA PATHS")
lib_path = os.environ['LD_LIBRARY_PATH'].strip()
for path in set(lib_path.split(':')):
try:
if isdir(path):
print_header(f"{path} CUDA PATHS")
paths = find_file_recursive(path, '*cuda*so')
print(paths)
except:
print(f'Could not read LD_LIBRARY_PATH: {path}')
print('')
def print_header(
txt: str, width: int = HEADER_WIDTH, filler: str = "+"
...@@ -21,28 +92,16 @@ def print_debug_info() -> None:
)
print_header("") generate_bug_report_information()
print_header("DEBUG INFORMATION")
print_header("")
print()
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items():
if "/" in v and not to_be_ignored(k, v):
print(f"'{k}': '{v}'")
print_header("")
print(
"\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n"
)
print_header("OTHER") print_header("OTHER")
print(f"{COMPILED_WITH_CUDA = }") print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
cuda = get_cuda_lib_handle() cuda = get_cuda_lib_handle()
print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}") print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}")
print_header("") print_header("")
...@@ -55,6 +114,7 @@ Running a quick check that:
+ CUDA function is callable
"""
)
print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n")
try:
from bitsandbytes.optim import Adam
...@@ -91,3 +151,4 @@ except Exception as e:
print(e)
print_debug_info()
sys.exit(1)
...@@ -221,9 +221,20 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False
CB = None
CxB = None
...@@ -263,6 +274,15 @@ class MatmulLtState:
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property
def tile_indices(self):
if self._tile_indices is None:
device = self.CxB.device
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
return self._tile_indices
class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
...@@ -270,7 +290,7 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
...@@ -456,13 +476,6 @@ class MatMul8bitLt(torch.autograd.Function):
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CxB is not None:
if state.tile_indices is None:
order, tile_size = state.formatB, state.get_tile_size()
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
with torch.no_grad():
state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)
......
...@@ -9,10 +9,8 @@ from bitsandbytes.cuda_setup.main import CUDASetup
setup = CUDASetup.get_instance()
if setup.initialized != True:
setup.run_cuda_setup()
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
setup.print_log_stack()
lib = setup.lib
try:
...@@ -20,15 +18,25 @@ try:
CUDASetup.get_instance().generate_instructions()
CUDASetup.get_instance().print_log_stack()
raise RuntimeError('''
CUDA Setup failed despite GPU being available. Please run the following command to get more information:
python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
lib.cget_managed_ptr.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError as ex:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
print(str(ex))
# print the setup details after checking for errors so we do not print twice
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
setup.print_log_stack()
...@@ -11,6 +11,7 @@ def to_be_ignored(env_var: str, value: str) -> bool:
"HOME", # Linux shell default
"TMUX", # Terminal Multiplexer
"XDG_DATA_DIRS", # XDG: Desktop environment stuff
"XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff
"XDG_RUNTIME_DIR",
"MAIL", # something related to emails
"SHELL", # binary for currently invoked shell
......
...@@ -21,12 +21,21 @@ import os
import errno
import torch
from warnings import warn
from itertools import product
from pathlib import Path
from typing import Set, Union
from .env_vars import get_potentially_lib_path_containing_env_vars
# these are the most common libs names
# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead
# we have libcudart.so.11.0 which causes a lot of errors before
# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0']
# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
backup_paths = []
backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0')
class CUDASetup:
_instance = None
...@@ -102,6 +111,8 @@ class CUDASetup:
package_dir = Path(__file__).parent.parent
binary_path = package_dir / binary_name
print('bin', binary_path)
try:
if not binary_path.exists():
self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
...@@ -121,7 +132,6 @@ class CUDASetup:
self.add_log_entry('='*80)
self.add_log_entry('')
self.generate_instructions()
self.print_log_stack()
raise Exception('CUDA SETUP: Setup Failed!')
self.lib = ct.cdll.LoadLibrary(binary_path)
else:
...@@ -129,7 +139,6 @@
self.lib = ct.cdll.LoadLibrary(binary_path)
except Exception as ex:
self.add_log_entry(str(ex))
self.print_log_stack()
def add_log_entry(self, msg, is_warning=False):
self.cuda_setup_log.append((msg, is_warning))
...@@ -154,7 +163,7 @@ def is_cublasLt_compatible(cc):
if cc is not None:
cc_major, cc_minor = cc.split('.')
if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
else:
has_cublaslt = True
return has_cublaslt
...@@ -182,11 +191,12 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
paths = set()
for libname in CUDA_RUNTIME_LIBS:
for path in candidate_paths:
if (path / libname).is_file():
paths.add(path / libname)
return paths
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
...@@ -206,12 +216,12 @@ def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
if len(results_paths) > 1:
warning_msg = (
f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
"We'll flip a coin and try one of these, in order to fail forward.\n"
"Either way, this might cause trouble in the future:\n"
"If you get `CUDA error: invalid device function` errors, the above "
"might be the cause and the solution is to make sure only one "
f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
...@@ -239,7 +249,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
return next(iter(conda_cuda_libs))
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
if "LD_LIBRARY_PATH" in candidate_env_vars:
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
...@@ -249,7 +259,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
warn_in_case_of_duplicates(lib_ld_cuda_libs)
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
remaining_candidate_env_vars = {
env_var: value for env_var, value in candidate_env_vars.items()
...@@ -261,7 +271,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
cuda_runtime_libs.update(find_cuda_lib_in(value))
if len(cuda_runtime_libs) == 0:
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...')
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
warn_in_case_of_duplicates(cuda_runtime_libs)
...@@ -367,9 +377,10 @@ def evaluate_cuda_setup():
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
print('='*80)
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
......
...@@ -28,59 +28,71 @@ name2qmap = {}
if COMPILED_WITH_CUDA:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
str2optimizer32bit["momentum"] = (
lib.cmomentum32bit_grad_32,
lib.cmomentum32bit_grad_16,
)
str2optimizer32bit["rmsprop"] = (
lib.crmsprop32bit_grad_32,
lib.crmsprop32bit_grad_16,
)
str2optimizer32bit["lion"] = (
lib.clion32bit_grad_32,
lib.clion32bit_grad_16,
)
str2optimizer32bit["adagrad"] = (
lib.cadagrad32bit_grad_32,
lib.cadagrad32bit_grad_16,
)
str2optimizer8bit = {}
str2optimizer8bit["adam"] = (
lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_16,
)
str2optimizer8bit["momentum"] = (
lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_16,
)
str2optimizer8bit["rmsprop"] = (
lib.crmsprop_static_8bit_grad_32,
lib.crmsprop_static_8bit_grad_16,
)
str2optimizer8bit["lion"] = (
lib.clion_static_8bit_grad_32,
lib.clion_static_8bit_grad_16,
)
str2optimizer8bit["lamb"] = (
lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_16,
)
str2optimizer8bit["lars"] = (
lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_16,
)
str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise["adam"] = (
lib.cadam_8bit_blockwise_grad_fp32,
lib.cadam_8bit_blockwise_grad_fp16,
lib.cadam_8bit_blockwise_grad_bf16,
)
str2optimizer8bit_blockwise["momentum"] = (
lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp16,
)
str2optimizer8bit_blockwise["rmsprop"] = (
lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp16,
)
str2optimizer8bit_blockwise["lion"] = (
lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp16,
)
str2optimizer8bit_blockwise["adagrad"] = (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
)
class GlobalPageManager:
...@@ -327,7 +339,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
values = []
lst = list(itertools.product([0, 1], repeat=precision_bits))
#for ev in evalues:
bias = 2**(exponent_bits-1)
for evalue in range(2**(exponent_bits)):
for bit_pattern in lst:
value = (1 if evalue != 0 else 0)
...@@ -335,10 +347,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
value += pval*(2**-(i+1))
if evalue == 0:
# subnormals
value = value*2**-(bias)
else:
# normals
value = value*2**-(evalue-bias-1)
values.append(value)
if signed:
values.append(-value)
...@@ -624,7 +636,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return out
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
""" """
Quantize tensor A in blocks of size 4096 values. Quantize tensor A in blocks of size 4096 values.
...@@ -640,8 +652,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra ...@@ -640,8 +652,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
The quantization map. The quantization map.
absmax : torch.Tensor absmax : torch.Tensor
The absmax values. The absmax values.
rand : torch.Tensor
The tensor for stochastic rounding.
out : torch.Tensor
The output tensor (8-bit).
...@@ -673,30 +683,17 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
# cpu
code = code.cpu()
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
if nested:
...@@ -754,13 +751,16 @@ def dequantize_blockwise(
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None:
quant_state = (absmax, code, blocksize)
assert absmax is not None and out is not None
else:
absmax, code, blocksize, nested, offset, state2 = quant_state
if nested:
absmax = dequantize_blockwise(absmax, state2)
absmax += offset
if A.device.type != 'cpu':
device = pre_call(A.device)
...@@ -994,9 +994,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor:
Quantized 8-bit tensor.
'''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out
...@@ -1021,9 +1023,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor:
32-bit output tensor.
'''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out
...@@ -1196,6 +1200,8 @@ def optimizer_update_8bit(
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit[optimizer_name][0](
get_ptr(p),
...@@ -1248,6 +1254,7 @@ def optimizer_update_8bit(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def optimizer_update_8bit_blockwise(
...@@ -1271,6 +1278,8 @@ def optimizer_update_8bit_blockwise(
) -> None:
optim_func = None
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
...@@ -1282,6 +1291,7 @@ def optimizer_update_8bit_blockwise(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
...@@ -1320,6 +1330,7 @@ def percentile_clipping(
The current optimization steps (number of past gradient norms).
"""
prev_device = pre_call(grad.device)
is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(
...@@ -1337,6 +1348,7 @@ def percentile_clipping(
)
else:
raise ValueError(f"Gradient type {grad.dtype} not supported!")
post_call(prev_device)
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec)
...@@ -2210,6 +2222,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
nnz = cooA.nnz
prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
...@@ -2284,6 +2297,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB,
)
# else: assertion error
post_call(prev_device)
return out
......
...@@ -2,4 +2,5 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
...@@ -9,7 +9,10 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
T = TypeVar("T", bound="torch.nn.Module") T = TypeVar("T", bound="torch.nn.Module")
...@@ -320,6 +323,53 @@ class Linear8bitLt(nn.Linear): ...@@ -320,6 +323,53 @@ class Linear8bitLt(nn.Linear):
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
# reorder weight layout back from ampere/turing to row
reorder_layout = True
weight_clone = self.weight.data.clone()
else:
reorder_layout = False
try:
if reorder_layout:
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
super()._save_to_state_dict(destination, prefix, keep_vars)
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
weight_name = "SCB"
# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, weight_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, weight_name)
key_name = prefix + f"{weight_name}"
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
elif not self.state.has_fp16_weights and param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
finally:
if reorder_layout:
self.weight.data = weight_clone
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
for key in unexpected_keys:
input_name = key[len(prefix):]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
unexpected_keys.remove(key)
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
...@@ -336,6 +386,7 @@ class Linear8bitLt(nn.Linear):
self.bias.data = self.bias.data.to(x.dtype)
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights:
if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
...@@ -343,3 +394,71 @@ class Linear8bitLt(nn.Linear):
del self.state.CB
self.weight.data = self.state.CxB
return out
class OutlierAwareLinear(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.outlier_dim = None
self.is_quantized = False
def forward_with_outliers(self, x, outlier_idx):
raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
def quantize_weight(self, w, outlier_idx):
raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')
def forward(self, x):
if self.outlier_dim is None:
tracer = OutlierTracer.get_instance()
if not tracer.is_initialized():
print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
outlier_idx = tracer.get_outliers(self.weight)
#print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx
if not self.is_quantized:
w = self.quantize_weight(self.weight, self.outlier_dim)
self.weight.data.copy_(w)
self.is_quantized = True
class SwitchBackLinearBnb(nn.Linear):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()
self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
def forward(self, x):
self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
import torch
import torch.nn as nn
import time
from functools import partial
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
class _switchback_global(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X, W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
X, W = ctx.save_for_backward
if ctx.needs_input_grad[0]:
# rowwise quantize for G, global quantize for W
# for W, we also fuse the transpose operation because only A @ B^T is supported
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
ctx.save_for_backward = X, W
# rowwise quantize for X
# columnwise quantize for W (first rowwise, transpose later)
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_rowwise(W)
# matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
X, W = ctx.save_for_backward
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
if ctx.needs_input_grad[0]:
# rowwise quantize for G, columnwise quantize for W and fused transpose
# we call .t() for weight later because only A @ B^T is supported
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_global_mem_efficient(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
X_3D_sz = X_3D.size()
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
del X
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X_int8, state_X, W_int8, state_W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
G_3D_sz = G_3D.size()
grad_X = grad_W = grad_bias = None
X_int8, state_X, W_int8, state_W = ctx.save_for_backward
if ctx.needs_input_grad[1]:
real_X = dequantize_rowwise(X_int8, state_X)
del X_int8
grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
del real_X
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
if ctx.needs_input_grad[0]:
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)
return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
vector_wise_quantization: bool = False,
mem_efficient : bool = False,
):
super().__init__(in_features, out_features, bias, device, dtype)
        if not is_triton_available():
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
# By default, we use the global quantization.
self.vector_wise_quantization = vector_wise_quantization
if self.vector_wise_quantization:
self._fn = _switchback_vectorrize
if mem_efficient:
print('mem efficient is not supported for vector-wise quantization.')
exit(1)
else:
if mem_efficient:
self._fn = _switchback_global_mem_efficient
else:
self._fn = _switchback_global
def prepare_for_eval(self):
# If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
# Note this is experimental and not tested thoroughly.
# Note this needs to be explicitly called with something like
# def cond_prepare(m):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print('=> preparing for eval.')
if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight)
else:
W_int8, state_W = quantize_global(self.weight)
self.register_buffer("W_int8", W_int8)
self.register_buffer("state_W", state_W)
del self.weight
def forward(self, x):
if self.training:
return self._fn.apply(x, self.weight, self.bias)
else:
# If it hasn't been "prepared for eval", run the standard forward pass.
if not hasattr(self, "W_int8"):
return self._fn.apply(x, self.weight, self.bias)
# Otherwise, use pre-computed weights.
X = x.view(-1, x.size(-1))
X_int8, state_X = quantize_rowwise(X)
if self.vector_wise_quantization:
return int8_matmul_rowwise_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias=None):
X = input.view(-1, input.size(-1))
ctx.save_for_backward(X, weight, bias)
output = input.matmul(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output.view(*input.size()[:-1], -1)
@staticmethod
def backward(ctx, grad_output_3D):
input, weight, bias = ctx.saved_tensors
grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
class StandardLinear(nn.Linear):
def forward(self, x):
return StandardLinearFunction.apply(x, self.weight, self.bias)
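A minimal usage sketch for the Triton-backed `SwitchBackLinear` defined above, including the optional `prepare_for_eval` step described in its comments. The shapes, dtypes, and import path are assumptions made for illustration; running it requires a CUDA device and a working Triton installation.

```python
import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear  # import path assumed

# Global quantization is the default; pass vector_wise_quantization=True for the rowwise variant.
layer = SwitchBackLinear(1024, 4096).cuda().half()
x = torch.randn(8, 128, 1024, device='cuda', dtype=torch.float16)

layer.train()
out = layer(x)  # training path: X and W are quantized on the fly each forward

# For inference, weights can be pre-quantized once instead of on every forward pass.
def cond_prepare(m):
    if hasattr(m, 'prepare_for_eval'):
        m.prepare_for_eval()

layer.eval()
layer.apply(cond_prepare)
out = layer(x)  # eval path: reuses the cached W_int8 / state_W buffers
```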
@@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .lion import Lion, Lion8bit, Lion32bit
 from .sgd import SGD, SGD8bit, SGD32bit
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion8bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion32bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
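A minimal usage sketch for the newly added Lion optimizers, following the same drop-in pattern as the other bnb optimizers. The model and hyperparameters below are placeholders.

```python
import torch
import bitsandbytes as bnb

# Placeholder model; any torch.nn.Module works.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 2048),
    torch.nn.ReLU(),
    torch.nn.Linear(2048, 512),
).cuda()

# 8-bit Lion with blockwise quantized optimizer state (block_wise=True is the default).
optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x = torch.randn(16, 512, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

`bnb.optim.Lion` with `optim_bits=32` (or `bnb.optim.Lion32bit`) gives the full-precision variant with the same interface.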
@@ -669,7 +669,7 @@ class Optimizer1State(Optimizer8bit):
                 step,
                 config["lr"],
                 None,
-                0.0,
+                config['betas'][1],
                 config["weight_decay"],
                 gnorm_scale,
                 state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
...
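The one-line change above forwards Lion's second momentum coefficient (`betas[1]`) to the single-state 32-bit update call, where a hard-coded `0.0` was previously passed. For reference, here is a plain-PyTorch sketch of the standard Lion update rule from the Lion paper (Chen et al., 2023), showing where beta2 enters; this illustrates the math only and is not the fused bitsandbytes kernel.

```python
import torch

def lion_update_reference(p, grad, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    # sign of an interpolation between the momentum buffer and the current gradient
    c = beta1 * m + (1 - beta1) * grad
    # decoupled weight decay, then the signed update
    p.mul_(1 - lr * weight_decay)
    p.add_(torch.sign(c), alpha=-lr)
    # the momentum buffer itself is updated with beta2 --
    # this is the coefficient the hunk above now passes through
    m.mul_(beta2).add_(grad, alpha=1 - beta2)
    return p, m
```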