Commit 7140c014 authored by Tim Dettmers's avatar Tim Dettmers

Merge branch 'main' into fp8_merge

parents dd562c24 32f8c892
...@@ -189,3 +189,35 @@ Improvements:
- StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu)
- runtime performance of block-wise quantization slightly improved
- added error message for the case where multiple libcudart.so are installed and bitsandbytes picks the wrong one
### 0.37.0
#### Int8 Matmul + backward support for all GPUs
Features:
- Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov
- Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int8 weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov
Improvements:
- Improved logging for the CUDA detection mechanism.
### 0.38.0
#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub
Features:
- Support for 32 and 8-bit Lion has been added. Thank you @lucidrains
- Support for serialization of Linear8bitLt layers (LLM.int8()). This allows storing and loading 8-bit weights directly from the HuggingFace Hub (a minimal sketch follows after this list). Thank you @myrab
- New bug report feature: `python -m bitsandbytes` now gives extensive debugging details to help diagnose CUDA setup failures.
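A minimal sketch of the new serialization path (hedged example: the layer sizes and file name are only illustrative, and a CUDA device is assumed):

```python
import torch
import bitsandbytes as bnb

# quantize a layer by moving it to the GPU, then save its 8-bit state dict
layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
layer = layer.cuda()  # weights are converted to int8 here
torch.save(layer.state_dict(), "linear_8bit.pt")  # illustrative path; stores the int8 weight plus its SCB scale statistics

# load the 8-bit checkpoint into a fresh layer; call .cuda() before load_state_dict()
restored = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
restored.load_state_dict(torch.load("linear_8bit.pt"))
```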
Bug fixes:
- Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins
- Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases.
Improvements:
- Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries
Deprecated:
- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
...@@ -60,8 +60,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90

all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
......
...@@ -11,10 +11,41 @@ Resources:

## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
(Deprecated: CUDA 10.0 is deprecated and only CUDA >= 11.0 will be supported with release 0.39.0)
**Installation**:
``pip install bitsandbytes``
In some cases you may need to compile from source. If that happens, please consider submitting a bug report with the information from `python -m bitsandbytes`. What follows are short instructions that might work out of the box if `nvcc` is installed. If they do not work, see the more detailed instructions further below.
Compilation quickstart:
```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes
# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```
**Using Int8 inference with HuggingFace Transformers**
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'decapoda-research/llama-7b-hf',
    device_map='auto',
    load_in_8bit=True,
    max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB')
```
A more detailed example can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py).
**Using 8-bit optimizer**:
1. Comment out optimizer: ``#torch.optim.Adam(....)``
2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same); a minimal sketch follows below
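A minimal sketch of that swap (the model below is only a stand-in; `bnb.optim.Adam8bit` takes the same arguments as `torch.optim.Adam`):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()                    # stand-in model
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # 1. comment out the 32-bit optimizer
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)   # 2. drop-in 8-bit optimizer, same arguments

out = model(torch.randn(16, 1024, device='cuda'))
out.sum().backward()
optimizer.step()
optimizer.zero_grad()
```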
...@@ -39,7 +70,7 @@ out = linear(x.to(torch.float16))

## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization, and normalization
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
- Fast quantile estimation: Up to 100x faster than other algorithms
...@@ -58,6 +89,10 @@ The bitsandbytes library is currently only supported on Linux distributions. Win

The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.
To install run:
``pip install bitsandbytes``
## Using bitsandbytes

### Using Int8 Matrix Multiplication

...@@ -108,8 +143,23 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md)
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)

## Compile from source
To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands.
```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc
bash cuda_install.sh 118 ~/local 1
```
To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`. For example, the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the CUDA version at `~/local/cuda-11.7`:
``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``
For more detailed instructions, please follow the [compile_from_source.md](compile_from_source.md) instructions.
## License
......
import os
import sys
import shlex
import subprocess
from warnings import warn
from typing import Tuple
from os.path import isdir

import torch

HEADER_WIDTH = 60
def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
return tuple(
to_decode.decode("UTF-8").strip()
for to_decode in subprocess_err_out_tuple
)
def execute_and_return_decoded_std_streams(command_string):
return _decode(
subprocess.Popen(
shlex.split(command_string),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
).communicate()
)
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err
def find_file_recursive(folder, filename):
cmd = f'find {folder} -name {filename}'
out, err = execute_and_return(cmd)
if len(err) > 0:
        raise RuntimeError('Something went wrong when trying to find the file. Maybe you do not have a Linux system?')
return out
def generate_bug_report_information():
print_header("")
print_header("BUG REPORT INFORMATION")
print_header("")
print('')
if 'CONDA_PREFIX' in os.environ:
paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so')
print_header("ANACONDA CUDA PATHS")
print(paths)
print('')
if isdir('/usr/local/'):
paths = find_file_recursive('/usr/local', '*cuda*so')
print_header("/usr/local CUDA PATHS")
print(paths)
print('')
if isdir(os.getcwd()):
paths = find_file_recursive(os.getcwd(), '*cuda*so')
print_header("WORKING DIRECTORY CUDA PATHS")
print(paths)
print('')
print_header("LD_LIBRARY CUDA PATHS")
lib_path = os.environ['LD_LIBRARY_PATH'].strip()
for path in set(lib_path.split(':')):
try:
if isdir(path):
print_header(f"{path} CUDA PATHS")
paths = find_file_recursive(path, '*cuda*so')
print(paths)
except:
print(f'Could not read LD_LIBRARY_PATH: {path}')
print('')
def print_header(
    txt: str, width: int = HEADER_WIDTH, filler: str = "+"
...@@ -21,28 +92,16 @@ def print_debug_info() -> None:
    )


generate_bug_report_information()


from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle

print_header("OTHER")
print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
cuda = get_cuda_lib_handle()
print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}")
print_header("")
...@@ -55,6 +114,7 @@ Running a quick check that:
    + CUDA function is callable
"""
)
print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n")

try:
    from bitsandbytes.optim import Adam
...@@ -91,3 +151,4 @@ except Exception as e:
    print(e)
    print_debug_info()
    sys.exit(1)
from ._functions import undo_layout, get_inverse_transform_indices
...@@ -2,6 +2,7 @@ import operator
import warnings
from dataclasses import dataclass
from functools import reduce  # Required in Python 3
from typing import Tuple, Optional

import torch

...@@ -14,6 +15,12 @@ def prod(iterable):

tensor = torch.Tensor
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
""" """
This class pools outlier dimensions across layers. This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features This is particularly important for small models where outlier features
...@@ -48,6 +55,51 @@ class GlobalOutlierPooler: ...@@ -48,6 +55,51 @@ class GlobalOutlierPooler:
return torch.Tensor(list(self.outliers)).to(torch.int64) return torch.Tensor(list(self.outliers)).to(torch.int64)
def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
"""
Compute a permutation of indices that invert the specified (tiled) matrix transformation
:param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
:param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
:note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
:example: transform_tile function for the turing layout (bitsandbytes.functional as F)
:returns: indices
"""
d1, d2 = tile_size
assert 0 < d1 * d2 < 2**64
tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
# encode each position in tile as a tuple of <= 8 unique bytes
permuted_tile_indices = torch.zeros_like(tile_indices)
for i in range(8):
# select i-th byte, apply transformation and trace where each index ended up
ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
permuted_tile_i = transform_tile(sample_tile_i)
ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
permuted_tile_indices += ith_permuted_indices * (256**i)
if d1 * d2 < 256**i:
break # if all indices fit in i bytes, stop early
return permuted_tile_indices
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
:param permuted_tensor: torch tensor in a permuted layout
:param tile_indices: reverse transformation indices, from get_inverse_transform_indices
:return: contiguous row-major tensor
"""
(rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda
outputs[tile_indices.flatten()] = tensor
outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
return outputs.reshape(rows, cols).contiguous()
class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):

...@@ -169,8 +221,21 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
@dataclass
class MatmulLtState:
    _tile_indices: Optional[torch.Tensor] = None
    force_no_igemmlt: bool = False
    CB = None
    CxB = None
    SB = None
...@@ -202,11 +267,31 @@ class MatmulLtState:
        self.SBt = None
        self.CBt = None
def get_tile_size(self):
assert self.formatB in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property
def tile_indices(self):
if self._tile_indices is None:
device = self.CxB.device
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
return self._tile_indices
class MatMul8bitLt(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
        using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
...@@ -214,9 +299,9 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
        # 1. Quantize A
        # 2. Quantize B
...@@ -235,9 +320,7 @@ class MatMul8bitLt(torch.autograd.Function):
        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
...@@ -248,12 +331,12 @@ class MatMul8bitLt(torch.autograd.Function):
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None and using_igemmlt:
                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None
...@@ -273,7 +356,10 @@ class MatMul8bitLt(torch.autograd.Function):
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                if using_igemmlt:
                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
                else:
                    state.CB = CB
        else:
            has_grad = False
...@@ -288,18 +374,17 @@ class MatMul8bitLt(torch.autograd.Function):
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            if state.CxB is not None:
                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            else:
                outliers = state.CB[:, state.idx.long()].clone()

            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0] if state.SB else B.shape

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
...@@ -307,16 +392,25 @@ class MatMul8bitLt(torch.autograd.Function):
            output_shape = (input_shape[0], shapeB[0])
        # 3. Matmul
        if using_igemmlt:
            C32A, SA = F.transform(CA, "col32")
            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
            if bias is None or bias.dtype == torch.float16:
                # we apply the fused bias here
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                output = output.to(A.dtype)
            else:  # apply bias separately
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
                output = output.to(A.dtype).add_(bias)
        else:
            A_wo_outliers = A.clone()
            if state.idx is not None:
                A_wo_outliers[:, state.idx.long()] = 0
            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
            if bias is not None:
                output = output.add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
...@@ -337,14 +431,13 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))
    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA = ctx.tensors
...@@ -359,9 +452,7 @@ class MatMul8bitLt(torch.autograd.Function):
        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
...@@ -376,17 +467,22 @@ class MatMul8bitLt(torch.autograd.Function):
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CB is not None:
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CxB is not None:
                CB = (
                    undo_layout(state.CxB, state.tile_indices)
                    .to(ctx.dtype_A)
                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                )
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
                raise Exception("State must contain either CBt or CB or CxB matrix for backward")

        return grad_A, grad_B, None, grad_bias, None
......
...@@ -11,8 +11,6 @@ from bitsandbytes.cuda_setup.main import CUDASetup

setup = CUDASetup.get_instance()
if setup.initialized != True:
    setup.run_cuda_setup()

lib = setup.lib
try:
...@@ -20,14 +18,22 @@ try:
        CUDASetup.get_instance().generate_instructions()
        CUDASetup.get_instance().print_log_stack()
        raise RuntimeError('''
        CUDA Setup failed despite GPU being available. Please run the following command to get more information:

        python -m bitsandbytes

        Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
        to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
        and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
    lib.cadam32bit_g32
    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn("The installed version of bitsandbytes was compiled without GPU support. "
         "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
    COMPILED_WITH_CUDA = False

# print the setup details after checking for errors so we do not print twice
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
    setup.print_log_stack()
...@@ -11,6 +11,7 @@ def to_be_ignored(env_var: str, value: str) -> bool:
        "HOME",  # Linux shell default
        "TMUX",  # Terminal Multiplexer
        "XDG_DATA_DIRS",  # XDG: Desktop environment stuff
        "XDG_GREETER_DATA_DIR",  # XDG: Desktop environment stuff
        "XDG_RUNTIME_DIR",
        "MAIL",  # something related to emails
        "SHELL",  # binary for currently invoked shell
......
...@@ -21,12 +21,21 @@ import os
import errno
import torch
from warnings import warn
from itertools import product
from pathlib import Path
from typing import Set, Union

from .env_vars import get_potentially_lib_path_containing_env_vars
# these are the most common libs names
# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead
# we have libcudart.so.11.0 which causes a lot of errors before
# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0']
# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
backup_paths = []
backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0')
class CUDASetup:
    _instance = None
...@@ -80,9 +89,10 @@ class CUDASetup:
        self.add_log_entry('python setup.py install')
    def initialize(self):
        if not getattr(self, 'initialized', False):
            self.has_printed = False
            self.lib = None
            self.initialized = False
    def run_cuda_setup(self):
        self.initialized = True
...@@ -97,13 +107,15 @@ class CUDASetup:
        package_dir = Path(__file__).parent.parent
        binary_path = package_dir / binary_name
print('bin', binary_path)
        try:
            if not binary_path.exists():
                self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
                legacy_binary_name = "libbitsandbytes_cpu.so"
                self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
                binary_path = package_dir / legacy_binary_name
                if not binary_path.exists() or torch.cuda.is_available():
                    self.add_log_entry('')
                    self.add_log_entry('='*48 + 'ERROR' + '='*37)
                    self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
...@@ -112,10 +124,10 @@ class CUDASetup:
                    self.add_log_entry('3. You have multiple conflicting CUDA libraries')
                    self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
                    self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
                    self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
                    self.add_log_entry('='*80)
                    self.add_log_entry('')
                    self.generate_instructions()
                    raise Exception('CUDA SETUP: Setup Failed!')
                self.lib = ct.cdll.LoadLibrary(binary_path)
            else:
...@@ -123,7 +135,6 @@ class CUDASetup:
                self.lib = ct.cdll.LoadLibrary(binary_path)
        except Exception as ex:
            self.add_log_entry(str(ex))
    def add_log_entry(self, msg, is_warning=False):
        self.cuda_setup_log.append((msg, is_warning))
...@@ -148,7 +159,7 @@ def is_cublasLt_compatible(cc):
    if cc is not None:
        cc_major, cc_minor = cc.split('.')
        if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
        else:
            has_cublaslt = True
    return has_cublaslt
...@@ -176,11 +187,12 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:

def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
    paths = set()
    for libname in CUDA_RUNTIME_LIBS:
        for path in candidate_paths:
            if (path / libname).is_file():
                paths.add(path / libname)
    return paths
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
...@@ -200,12 +212,12 @@ def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: ...@@ -200,12 +212,12 @@ def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
if len(results_paths) > 1: if len(results_paths) > 1:
warning_msg = ( warning_msg = (
f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. " f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
"We'll flip a coin and try one of these, in order to fail forward.\n" "We'll flip a coin and try one of these, in order to fail forward.\n"
"Either way, this might cause trouble in the future:\n" "Either way, this might cause trouble in the future:\n"
"If you get `CUDA error: invalid device function` errors, the above " "If you get `CUDA error: invalid device function` errors, the above "
"might be the cause and the solution is to make sure only one " "might be the cause and the solution is to make sure only one "
f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.") f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
...@@ -233,7 +245,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: ...@@ -233,7 +245,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
return next(iter(conda_cuda_libs)) return next(iter(conda_cuda_libs))
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
if "LD_LIBRARY_PATH" in candidate_env_vars: if "LD_LIBRARY_PATH" in candidate_env_vars:
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
...@@ -243,7 +255,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: ...@@ -243,7 +255,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
warn_in_case_of_duplicates(lib_ld_cuda_libs) warn_in_case_of_duplicates(lib_ld_cuda_libs)
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
remaining_candidate_env_vars = { remaining_candidate_env_vars = {
env_var: value for env_var, value in candidate_env_vars.items() env_var: value for env_var, value in candidate_env_vars.items()
...@@ -255,7 +267,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: ...@@ -255,7 +267,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
cuda_runtime_libs.update(find_cuda_lib_in(value)) cuda_runtime_libs.update(find_cuda_lib_in(value))
if len(cuda_runtime_libs) == 0: if len(cuda_runtime_libs) == 0:
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...') CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...')
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
warn_in_case_of_duplicates(cuda_runtime_libs) warn_in_case_of_duplicates(cuda_runtime_libs)
...@@ -361,10 +373,10 @@ def evaluate_cuda_setup(): ...@@ -361,10 +373,10 @@ def evaluate_cuda_setup():
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
print('') print('')
print('='*35 + 'BUG REPORT' + '='*35) print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
print('='*80) print('='*80)
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
cuda_setup = CUDASetup.get_instance() cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path() cudart_path = determine_cuda_runtime_lib_path()
......
...@@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA: ...@@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop32bit_g32, lib.crmsprop32bit_g32,
lib.crmsprop32bit_g16, lib.crmsprop32bit_g16,
) )
str2optimizer32bit["lion"] = (
lib.clion32bit_g32,
lib.clion32bit_g16,
)
str2optimizer32bit["adagrad"] = ( str2optimizer32bit["adagrad"] = (
lib.cadagrad32bit_g32, lib.cadagrad32bit_g32,
lib.cadagrad32bit_g16, lib.cadagrad32bit_g16,
...@@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA: ...@@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g32,
lib.crmsprop_static_8bit_g16, lib.crmsprop_static_8bit_g16,
) )
str2optimizer8bit["lion"] = (
lib.clion_static_8bit_g32,
lib.clion_static_8bit_g16,
)
str2optimizer8bit["lamb"] = ( str2optimizer8bit["lamb"] = (
lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g32,
lib.cadam_static_8bit_g16, lib.cadam_static_8bit_g16,
...@@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA: ...@@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp32,
lib.crmsprop_8bit_blockwise_fp16, lib.crmsprop_8bit_blockwise_fp16,
) )
str2optimizer8bit_blockwise["lion"] = (
lib.clion_8bit_blockwise_fp32,
lib.clion_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["adagrad"] = ( str2optimizer8bit_blockwise["adagrad"] = (
lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp32,
lib.cadagrad_8bit_blockwise_fp16, lib.cadagrad_8bit_blockwise_fp16,
...@@ -655,9 +667,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ...@@ -655,9 +667,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor: torch.Tensor:
Quantized 8-bit tensor. Quantized 8-bit tensor.
''' '''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8) if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out]) is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out return out
...@@ -682,9 +696,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ...@@ -682,9 +696,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor: torch.Tensor:
32-bit output tensor. 32-bit output tensor.
''' '''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32) if out is None: out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out]) is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out return out
...@@ -753,6 +769,8 @@ def optimizer_update_32bit( ...@@ -753,6 +769,8 @@ def optimizer_update_32bit(
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
) )
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec])
if g.dtype == torch.float32 and state1.dtype == torch.float32: if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0]( str2optimizer32bit[optimizer_name][0](
get_ptr(g), get_ptr(g),
...@@ -795,6 +813,7 @@ def optimizer_update_32bit( ...@@ -795,6 +813,7 @@ def optimizer_update_32bit(
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
) )
post_call(prev_device)
def optimizer_update_8bit( def optimizer_update_8bit(
...@@ -873,6 +892,8 @@ def optimizer_update_8bit( ...@@ -873,6 +892,8 @@ def optimizer_update_8bit(
if max_unorm > 0.0: if max_unorm > 0.0:
param_norm = torch.norm(p.data.float()) param_norm = torch.norm(p.data.float())
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8: if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit[optimizer_name][0]( str2optimizer8bit[optimizer_name][0](
get_ptr(p), get_ptr(p),
...@@ -925,6 +946,7 @@ def optimizer_update_8bit( ...@@ -925,6 +946,7 @@ def optimizer_update_8bit(
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
) )
post_call(prev_device)
def optimizer_update_8bit_blockwise( def optimizer_update_8bit_blockwise(
...@@ -947,6 +969,8 @@ def optimizer_update_8bit_blockwise( ...@@ -947,6 +969,8 @@ def optimizer_update_8bit_blockwise(
skip_zeros=False, skip_zeros=False,
) -> None: ) -> None:
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8: if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][0]( str2optimizer8bit_blockwise[optimizer_name][0](
get_ptr(p), get_ptr(p),
...@@ -991,6 +1015,7 @@ def optimizer_update_8bit_blockwise( ...@@ -991,6 +1015,7 @@ def optimizer_update_8bit_blockwise(
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
) )
post_call(prev_device)
def percentile_clipping( def percentile_clipping(
...@@ -1006,6 +1031,7 @@ def percentile_clipping( ...@@ -1006,6 +1031,7 @@ def percentile_clipping(
        The current optimization steps (number of past gradient norms).
""" """
prev_device = pre_call(grad.device)
is_on_gpu([grad, gnorm_vec]) is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32: if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32( lib.cpercentile_clipping_g32(
...@@ -1023,6 +1049,7 @@ def percentile_clipping( ...@@ -1023,6 +1049,7 @@ def percentile_clipping(
) )
else: else:
raise ValueError(f"Gradient type {grad.dtype} not supported!") raise ValueError(f"Gradient type {grad.dtype} not supported!")
post_call(prev_device)
current_gnorm = torch.sqrt(gnorm_vec[step % 100]) current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec) vals, idx = torch.sort(gnorm_vec)
...@@ -1779,6 +1806,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ...@@ -1779,6 +1806,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
) )
nnz = cooA.nnz nnz = cooA.nnz
prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz assert cooA.values.numel() == nnz
...@@ -1855,6 +1883,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ...@@ -1855,6 +1883,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB, ccolsB,
) )
# else: assertion error # else: assertion error
post_call(prev_device)
return out return out
......
...@@ -9,6 +9,8 @@ import torch.nn.functional as F ...@@ -9,6 +9,8 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims from bitsandbytes.utils import OutlierTracer, find_outlier_dims
...@@ -238,19 +240,10 @@ class Int8Params(torch.nn.Parameter): ...@@ -238,19 +240,10 @@ class Int8Params(torch.nn.Parameter):
class Linear8bitLt(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0, index=None):
        super().__init__(input_features, output_features, bias)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"

        self.state = bnb.MatmulLtState()
        self.index = index
...@@ -260,9 +253,54 @@ class Linear8bitLt(nn.Linear):
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
# reorder weight layout back from ampere/turing to row
reorder_layout = True
weight_clone = self.weight.data.clone()
else:
reorder_layout = False
try:
if reorder_layout:
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
super()._save_to_state_dict(destination, prefix, keep_vars)
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
weight_name = "SCB"
# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, weight_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, weight_name)
key_name = prefix + f"{weight_name}"
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
elif not self.state.has_fp16_weights and param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
finally:
if reorder_layout:
self.weight.data = weight_clone
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
for key in unexpected_keys:
input_name = key[len(prefix):]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
unexpected_keys.remove(key)
    def init_8bit_state(self):
        self.state.CB = self.weight.CB
...@@ -270,30 +308,23 @@ class Linear8bitLt(nn.Linear):
        self.weight.CB = None
        self.weight.SCB = None
    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
...@@ -336,22 +367,4 @@ class SwitchBackLinearBnb(nn.Linear):
        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
...@@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .lion import Lion, Lion8bit, Lion32bit
from .sgd import SGD, SGD8bit, SGD32bit
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion8bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion32bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
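A short usage sketch for the optimizer classes defined above (argument names and defaults mirror the constructors; the model, learning rate and weight decay are only illustrative):
```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()
# Lion with 8-bit optimizer state; Lion / Lion32bit accept the same arguments,
# with optim_bits selecting the state precision passed to the base class.
optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4,
                               betas=(0.9, 0.99), weight_decay=1e-2)

x = torch.randn(16, 512, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```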
...@@ -665,7 +665,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -665,7 +665,7 @@ class Optimizer1State(Optimizer8bit):
step, step,
config["lr"], config["lr"],
None, None,
0.0, config['betas'][1],
config["weight_decay"], config["weight_decay"],
gnorm_scale, gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None, state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......
# Compiling from source
Basic steps.
1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly`
2. `python setup.py install`
To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).
You can install CUDA locally without sudo by following these steps:
```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
# Syntax: cuda_install.sh CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
bash cuda_install.sh 117 ~/local 1
```
By default, the Makefile will look at your `CUDA_HOME` environment variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
Either `nvcc` needs to be on your `PATH` or the `CUDA_HOME` variable needs to be set to the CUDA root directory (e.g. `/usr/local/cuda`) in order for compilation to succeed.
You can run `python -m bitsandbytes` to find the path to CUDA. For example, if `python -m bitsandbytes` shows you the following:
```
++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so
```
You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might then be able to compile like this:
``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``
If you have problems compiling the library with these instructions from source, please open an issue.
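After a successful build, a quick sanity check can confirm that the compiled library loads and that an 8-bit optimizer step runs on the GPU (a minimal sketch assuming the install above worked; this is not an official test):
```python
import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(4096, 16, device="cuda"))
optimizer = bnb.optim.Lion8bit([p], lr=1e-4)

p.sum().backward()   # produces a dummy gradient of ones
optimizer.step()     # exercises the compiled 8-bit CUDA kernels
print("bitsandbytes CUDA build looks functional")
```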
...@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) { ...@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
return __int_as_float(old); return __int_as_float(old);
} }
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T>
__device__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}
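For reference, the Lion branches added to the kernels below implement the following update, sketched here in plain PyTorch (a simplification that ignores gradient-norm scaling and `update_scale`; weight decay is decoupled, as in the 8-bit kernels):
```python
import torch

def lion_update(p, g, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One Lion step: the sign of a beta1-interpolated momentum drives the
    parameter update; the stored momentum is then refreshed with beta2."""
    if weight_decay > 0.0:
        p.mul_(1.0 - lr * weight_decay)            # decoupled weight decay
    update = (beta1 * m + (1.0 - beta1) * g).sign()
    p.add_(update, alpha=-lr)                      # parameter step
    m.mul_(beta2).add_(g, alpha=1.0 - beta2)       # momentum update (after the step)
    return p, m
```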
template <int STOCHASTIC> template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{ {
...@@ -745,7 +753,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> ...@@ -745,7 +753,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n) const int step, const float lr, const float gnorm_scale, const int n)
{ {
...@@ -792,6 +800,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, ...@@ -792,6 +800,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break; break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
...@@ -823,7 +834,7 @@ template<typename T, int OPTIMIZER> ...@@ -823,7 +834,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1) __launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p, __global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm, float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{ {
...@@ -892,6 +903,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p, ...@@ -892,6 +903,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break; break;
case LION:
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
...@@ -1160,7 +1175,7 @@ __global__ void ...@@ -1160,7 +1175,7 @@ __global__ void
__launch_bounds__(NUM_THREADS, 2) __launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm, float *unorm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
...@@ -1221,6 +1236,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c ...@@ -1221,6 +1236,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
if(unorm != NULL) if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j]; local_unorm += s1_vals[j]*s1_vals[j];
break; break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break; break;
...@@ -1244,9 +1262,10 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c ...@@ -1244,9 +1262,10 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
template<typename T, int OPTIMIZER> template<typename T, int OPTIMIZER>
__global__ void __global__ void
__launch_bounds__(1024, 1)
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm, const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float lr, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
...@@ -1309,8 +1328,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, ...@@ -1309,8 +1328,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
{ {
g_val = float(g_vals[j]); g_val = float(g_vals[j]);
g_val *= gnorm_scale; g_val *= gnorm_scale;
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay; if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case MOMENTUM:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
switch(OPTIMIZER) switch(OPTIMIZER)
...@@ -1323,6 +1353,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, ...@@ -1323,6 +1353,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break; break;
case LION:
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
...@@ -1651,10 +1685,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1651,10 +1685,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{ {
g_val = float(g_vals[j]); g_val = float(g_vals[j]);
g_val *= gnorm_scale; g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{ {
if(weight_decay > 0.0f) if(weight_decay > 0.0f) {
g_val += ((float)p_vals[j])*weight_decay; switch(OPTIMIZER) {
case MOMENTUM:
case ADAGRAD:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
...@@ -1666,6 +1710,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1666,6 +1710,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
else else
s1_vals[j] = (s1_vals[j]*beta1) + g_val; s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break; break;
case LION:
// here, g_vals[j] is reused to store the signed, lr-scaled update (sign of the beta1-smoothed gradient) applied in the parameter-update loop below, before the momentum is updated with beta2
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break; break;
...@@ -1703,6 +1752,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1703,6 +1752,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case MOMENTUM: case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break; break;
case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break;
case RMSPROP: case RMSPROP:
g_val = g_vals[j]; g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
...@@ -2694,24 +2746,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c ...@@ -2694,24 +2746,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \ float* state1, float *unorm, \
const float beta1, const float eps, const float weight_decay, \ const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \ const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
#define MAKE_Optimizer32bit1State(oname, gtype) \ #define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float) MAKE_Optimizer32bit1State(ADAGRAD, float)
...@@ -2733,6 +2789,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, ...@@ -2733,6 +2789,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \ float *unorm, \
const float beta1, \ const float beta1, \
const float beta2, \
const float eps, const int step, \ const float eps, const int step, \
float* __restrict__ const quantiles1, \ float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \ float* max1, float* new_max1, \
...@@ -2744,11 +2801,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half) ...@@ -2744,11 +2801,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float) MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half) MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float) MAKE_PreconditionStatic8bit1State(RMSPROP, float)
MAKE_PreconditionStatic8bit1State(LION, half)
MAKE_PreconditionStatic8bit1State(LION, float)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \ #define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \ template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \ const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \ const float beta1, \
const float beta2, \
const float eps, const int step, const float lr, \ const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \ float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \ float* max1, float* new_max1, \
...@@ -2760,6 +2820,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half) ...@@ -2760,6 +2820,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float) MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half) MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float) MAKE_optimizerStatic8bit1State(RMSPROP, float)
MAKE_optimizerStatic8bit1State(LION, half)
MAKE_optimizerStatic8bit1State(LION, float)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ #define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
...@@ -2863,5 +2925,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) ...@@ -2863,5 +2925,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
...@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p, ...@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n); const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER> template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p, __global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm, float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER> template<typename T, int OPTIMIZER>
__global__ void __global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm, float *unorm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
...@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER> ...@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
__global__ void __global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm, const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float lr, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
......
...@@ -118,17 +118,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, ...@@ -118,17 +118,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
if(max_unorm > 0.0f) if(max_unorm > 0.0f)
{ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
break;
} }
} }
...@@ -162,12 +173,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, ...@@ -162,12 +173,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
default: default:
break; break;
} }
...@@ -196,6 +217,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g ...@@ -196,6 +217,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
case LION:
num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr, kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
...@@ -705,6 +727,8 @@ MAKE_optimizer32bit(MOMENTUM, half) ...@@ -705,6 +727,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, float)
...@@ -724,6 +748,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half) ...@@ -724,6 +748,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half) MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float) MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \ template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
...@@ -736,6 +762,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); ...@@ -736,6 +762,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
......
...@@ -70,6 +70,7 @@ typedef enum Optimizer_t ...@@ -70,6 +70,7 @@ typedef enum Optimizer_t
RMSPROP = 2, RMSPROP = 2,
LARS = 3, LARS = 3,
ADAGRAD = 4, ADAGRAD = 4,
LION = 5,
} Optimizer_t; } Optimizer_t;
typedef enum Transform_t typedef enum Transform_t
......
...@@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32) ...@@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32(adam, ADAM, half, 16) MAKE_FUNC32(adam, ADAM, half, 16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(lion, LION, float, 32)
MAKE_FUNC32(lion, LION, half, 16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16) MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
...@@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32) ...@@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
MAKE_FUNC8(momentum, MOMENTUM, half, 16) MAKE_FUNC8(momentum, MOMENTUM, half, 16)
MAKE_FUNC8(rmsprop, RMSPROP, float, 32) MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
MAKE_FUNC8(rmsprop, RMSPROP, half, 16) MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
...@@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) ...@@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_BLOCKWISE8(lion, LION, half, 16)
MAKE_BLOCKWISE8(lion, LION, float, 32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
...@@ -161,6 +167,8 @@ extern "C" ...@@ -161,6 +167,8 @@ extern "C"
MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, 32)
MAKE_CFUNC32(lion, half, 16)
MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16) MAKE_CFUNC32(adagrad, half, 16)
...@@ -183,6 +191,8 @@ extern "C" ...@@ -183,6 +191,8 @@ extern "C"
MAKE_CFUNC8(momentum, half, 16) MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32) MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16) MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
...@@ -196,6 +206,8 @@ extern "C" ...@@ -196,6 +206,8 @@ extern "C"
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_CBLOCKWISE8(lion, LION, half, 16)
MAKE_CBLOCKWISE8(lion, LION, float, 32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
......