Commit bfa0e332 authored by Titus von Koeller

ran black and isort for coherent code formatting

parent 597a8521
@@ -3,14 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .nn import modules
from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
matmul_cublas, mm_cublas)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
if COMPILED_WITH_CUDA:
from .optim import adam
__pdoc__ = {'libbitsandbytes': False,
'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False
}
__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
"optim.optimizer.MockArgs": False,
}
from dataclasses import dataclass
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from dataclasses import dataclass
tensor = torch.Tensor
'''
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
'''
"""
class GlobalOutlierPooler(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.outliers = set()
@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
return cls._instance
def add_outliers(self, outlier_idx, feature_dim):
if self.model_dim is None: self.model_dim = feature_dim
if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
if self.model_dim is None:
self.model_dim = feature_dim
if feature_dim != self.model_dim:
return # we do not encode outliers for the 2nd FFN layer
self.outliers.update(outlier_idx.tolist())
def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
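For reference, an illustrative sketch of how the pooler above is meant to be used (indices and dimensions are made up); only outliers from layers matching the first recorded feature dimension are pooled:

import torch

pooler = GlobalOutlierPooler.get_instance()
pooler.add_outliers(torch.tensor([7, 121]), feature_dim=1024)  # recorded, sets model_dim
pooler.add_outliers(torch.tensor([3]), feature_dim=4096)       # ignored: 2nd FFN dimension
idx = pooler.get_current_outlier_idx()                         # int64 tensor of pooled indices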
class MatMul8bit(torch.autograd.Function):
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
if len(B.shape) == 2: dim = 0
else: dim = 1
if len(B.shape) == 2:
dim = 0
else:
dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
if not grad_output.is_contiguous(): grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
if not A.is_contiguous(): A = A.contiguous()
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
if not grad_output.is_contiguous():
grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(
grad_output.view(-1, grad_output.shape[2]),
dim=0,
quant_type=quant_type,
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(
A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
grad_B = F.vectorwise_mm_dequant(
igrad_B, S2.t(), S1, grad_output.dtype, quant_type
)
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
S2.permute(permute_dim),
S1,
grad_output.dtype,
quant_type,
)
if A.requires_grad:
if len(grad_output.shape) == 3: dims = [2]
else: dims = [1]
if len(grad_output.shape) == 3:
dims = [2]
else:
dims = [1]
if len(B.shape) == 3:
# bio -> boi
@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
grad_A = F.vectorwise_mm_dequant(
igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
)
return grad_A, grad_B, None, None, None
@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
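As a quick illustration of the cuBLAS-based 8-bit matmul bound above (shapes are arbitrary and a CUDA device is assumed): the inputs are quantized vector-wise, multiplied with an int8 GEMM, and dequantized back to the input dtype.

import torch
import bitsandbytes as bnb

A = torch.randn(32, 64, device="cuda")
B = torch.randn(64, 16, device="cuda")
C = bnb.mm_cublas(A, B)   # approximates A @ B via the int8 path of MatMul8bit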
@dataclass
class MatmulLtState:
CB = None
@@ -159,7 +191,6 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# 1. Quantize A
@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
requires_gradB = B.requires_grad
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
assert (
A.dtype == torch.float16
), f"The input data type needs to be fp16 but {A.dtype} was found!"
# 1. Quantize A
if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
#state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
# if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
#if state.idx is not None:
# if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
#else:
# else:
# subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
# 2. Quantize B
if state.has_fp16_weights:
has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: B = B.contiguous()
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
#state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
#if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
#else:
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
#clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
# clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone
return clone_func(output.view(output_shape))
@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, 'col32')
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
ctx.grad_shape
)
return grad_A, grad_B, None, None, None, None, None
@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
matmul = MatMul8bitLt.apply
def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
def matmul(
A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, state)
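An illustrative call of the wrapper defined above (shapes, threshold, and device are hypothetical; the forward pass requires fp16 inputs, the backward pass requires fp16 weights, and the int8 kernels need a GPU with tensor-core support):

import torch
import bitsandbytes as bnb

A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
B = torch.randn(2048, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
out = bnb.matmul(A, B, threshold=6.0)   # mixed 8-bit matmul, computes A @ B^T -> (4, 2048)
out.sum().backward()                    # gradients flow through MatMul8bitLt.backward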
import ctypes as ct
import os
from warnings import warn
from bitsandbytes.cuda_setup import evaluate_cuda_setup
@@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = {}
binary_name = evaluate_cuda_setup()
if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'):
print(f'TODO: compile library for specific version: {binary_name}')
print('defaulting to libbitsandbytes.so')
self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
print(f"TODO: compile library for specific version: {binary_name}")
print("defaulting to libbitsandbytes.so")
self.lib = ct.cdll.LoadLibrary(
os.path.dirname(__file__) + "/libbitsandbytes.so"
)
else:
self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}')
self.lib = ct.cdll.LoadLibrary(
os.path.dirname(__file__) + f"/{binary_name}"
)
@classmethod
def get_instance(cls):
@@ -35,6 +40,8 @@ try:
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
warn(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable."
)
COMPILED_WITH_CUDA = False
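For context, a minimal sketch of how this flag and the shared library handle are consumed downstream (the dictionary entry mirrors the one in the functional module later in this diff):

from bitsandbytes.cextension import COMPILED_WITH_CUDA, lib

if COMPILED_WITH_CUDA:
    str2optimizer32bit = {"adam": (lib.cadam32bit_g32, lib.cadam32bit_g16)}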
@@ -18,31 +18,36 @@ evaluation:
- based on that set the default path
"""
import ctypes
import shlex
import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union
from .utils import warn_of_missing_prerequisite, print_err
import ctypes
import shlex
import subprocess
from .utils import print_err, warn_of_missing_prerequisite
def execute_and_return(strCMD):
proc = subprocess.Popen(shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc = subprocess.Popen(
shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
out, err = proc.communicate()
out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
return out, err
def check_cuda_result(cuda, result_val):
if result_val != 0:
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
print(f"Count not initialize CUDA - failure!")
raise Exception('CUDA exception!')
raise Exception("CUDA exception!")
return result_val
# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
def get_compute_capability():
libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
@@ -51,8 +56,7 @@ def get_compute_capability():
else:
break
else:
raise OSError("could not load any of: " + ' '.join(libnames))
raise OSError("could not load any of: " + " ".join(libnames))
nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
@@ -69,39 +73,43 @@ def get_compute_capability():
ccs = []
for i in range(nGpus.value):
result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
result = check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device))
ccs.append(f'{cc_major.value}.{cc_minor.value}')
result = check_cuda_result(
cuda,
cuda.cuDeviceComputeCapability(
ctypes.byref(cc_major), ctypes.byref(cc_minor), device
),
)
ccs.append(f"{cc_major.value}.{cc_minor.value}")
#TODO: handle different compute capabilities; for now, take the max
# TODO: handle different compute capabilities; for now, take the max
ccs.sort()
return ccs[-1]
# return ccs[-1]
return ccs
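A brief sketch of calling the detection helper above; device values are hypothetical. After this change it returns the full list of per-GPU compute capabilities, whereas the previous code returned only the last (highest) entry of the sorted list:

ccs = get_compute_capability()   # e.g. ["7.5", "8.0"] on a two-GPU machine
highest = sorted(ccs)[-1]        # reproduces the old single-value behaviour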
CUDA_RUNTIME_LIB: str = "libcudart.so"
def tokenize_paths(paths: str) -> Set[Path]:
return {
Path(ld_path) for ld_path in paths.split(':')
if ld_path
}
return {Path(ld_path) for ld_path in paths.split(":") if ld_path}
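An illustrative call of the tokenizer above (paths are made up); empty entries produced by stray colons are filtered out:

candidate_dirs = tokenize_paths("/usr/local/cuda/lib64:/opt/conda/lib:")
# -> {PosixPath('/usr/local/cuda/lib64'), PosixPath('/opt/conda/lib')}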
def get_cuda_runtime_lib_path(
# TODO: replace this with logic for all paths in env vars
LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
) -> Union[Path, None]:
""" # TODO: add doc-string
"""
"""# TODO: add doc-string"""
if not LD_LIBRARY_PATH:
warn_of_missing_prerequisite(
'LD_LIBRARY_PATH is completely missing from environment!'
"LD_LIBRARY_PATH is completely missing from environment!"
)
return None
ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)
non_existent_directories: Set[Path] = {
path for path in ld_library_paths
if not path.exists()
path for path in ld_library_paths if not path.exists()
}
if non_existent_directories:
@@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(
)
cuda_runtime_libs: Set[Path] = {
path / CUDA_RUNTIME_LIB for path in ld_library_paths
path / CUDA_RUNTIME_LIB
for path in ld_library_paths
if (path / CUDA_RUNTIME_LIB).is_file()
} - non_existent_directories
@@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(
single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
return single_cuda_runtime_lib_dir
def evaluate_cuda_setup():
cuda_path = get_cuda_runtime_lib_path()
cc = get_compute_capability()
binary_name = 'libbitsandbytes_cpu.so'
binary_name = "libbitsandbytes_cpu.so"
if not (has_gpu := bool(cc)):
print('WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...')
print(
"WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library..."
)
return binary_name
has_cublaslt = cc in ['7.5', '8.0', '8.6']
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
# TODO:
# (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed
cuda_home = str(Path(cuda_path).parent.parent)
ls_output, err = execute_and_return(f'{cuda_home}/bin/nvcc --version')
cuda_version = ls_output.split('\n')[3].split(',')[-1].strip().lower().replace('v', '')
major, minor, revision = cuda_version.split('.')
cuda_version_string = f'{major}{minor}'
ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
cuda_version = (
ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
)
major, minor, revision = cuda_version.split(".")
cuda_version_string = f"{major}{minor}"
binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
import typer
cli = typer.Typer()
@@ -9,47 +9,68 @@ from typing import Tuple
import torch
from torch import Tensor
from .cextension import lib, COMPILED_WITH_CUDA
from .cextension import COMPILED_WITH_CUDA, lib
name2qmap = {}
if COMPILED_WITH_CUDA:
''' C FUNCTIONS FOR OPTIMIZERS '''
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer8bit = {}
str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16)
str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit["momentum"] = (
lib.cmomentum_static_8bit_g32,
lib.cmomentum_static_8bit_g16,
)
str2optimizer8bit["rmsprop"] = (
lib.crmsprop_static_8bit_g32,
lib.crmsprop_static_8bit_g16,
)
str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit["lars"] = (
lib.cmomentum_static_8bit_g32,
lib.cmomentum_static_8bit_g16,
)
str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
str2optimizer8bit_blockwise["adam"] = (
lib.cadam_8bit_blockwise_fp32,
lib.cadam_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["momentum"] = (
lib.cmomentum_8bit_blockwise_fp32,
lib.cmomentum_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["rmsprop"] = (
lib.crmsprop_8bit_blockwise_fp32,
lib.crmsprop_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["adagrad"] = (
lib.cadagrad_8bit_blockwise_fp32,
lib.cadagrad_8bit_blockwise_fp16,
)
class CUBLAS_Context(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = {}
#prev_device = torch.cuda.current_device()
#for i in range(torch.cuda.device_count()):
# prev_device = torch.cuda.current_device()
# for i in range(torch.cuda.device_count()):
# torch.cuda.set_device(torch.device('cuda', i))
# self.context.append(ct.c_void_p(lib.get_context()))
#torch.cuda.set_device(prev_device)
# torch.cuda.set_device(prev_device)
@classmethod
def get_instance(cls):
@@ -66,11 +87,12 @@ class CUBLAS_Context(object):
torch.cuda.set_device(prev_device)
return self.context[device.index]
class Cusparse_Context(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = ct.c_void_p(lib.get_cusparse())
@@ -82,14 +104,16 @@ class Cusparse_Context(object):
cls._instance.initialize()
return cls._instance
def create_linear_map(signed=True):
if signed:
return torch.linspace(-1.0, 1.0, 256)
else:
return torch.linspace(0.0, 1.0, 256)
def create_dynamic_map(signed=True, n=7):
'''
"""
Creates the dynamic quantization map.
The dynamic data type is made up of a dynamic exponent and
@@ -103,46 +127,54 @@ def create_dynamic_map(signed=True, n=7):
For more details see
[8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
'''
"""
data = []
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
additional_items = 2**(7-n)-1
if not signed: additional_items = 2*additional_items
additional_items = 2 ** (7 - n) - 1
if not signed:
additional_items = 2 * additional_items
for i in range(n):
fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1
fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1]+boundaries[1:])/2.0
data += ((10**(-(n-1)+i))*means).tolist()
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
data += (-(10**(-(n-1)+i))*means).tolist()
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items+1)
means = (boundaries[:-1]+boundaries[1:])/2.0
data += ((10**(-(n-1)+i))*means).tolist()
boundaries = torch.linspace(0.1, 1, additional_items + 1)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
data += (-(10**(-(n-1)+i))*means).tolist()
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
data.append(0)
data.append(1.0)
data.sort()
return Tensor(data)
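A quick sketch of inspecting the map built above (no GPU required); with the defaults it yields 256 codes in [-1, 1], appending 0 and 1.0 explicitly and spacing the remaining values roughly logarithmically toward zero:

import bitsandbytes.functional as F

code = F.create_dynamic_map(signed=True)
assert code.numel() == 256
assert code.max().item() == 1.0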
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
print(
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
)
assert major >= 7
if major == 7: return 'col_turing'
elif major == 8: return 'col_ampere'
else: return 'col_turing'
if major == 7:
return "col_turing"
elif major == 8:
return "col_ampere"
else:
return "col_turing"
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
"""
Get the ctypes pointer from a PyTorch Tensor.
Parameters
@@ -153,31 +185,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
Returns
-------
ctypes.c_void_p
'''
if A is None: return None
else: return ct.c_void_p(A.data.storage().data_ptr())
"""
if A is None:
return None
else:
return ct.c_void_p(A.data.storage().data_ptr())
def pre_call(device):
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
return prev_device
def post_call(prev_device):
torch.cuda.set_device(prev_device)
def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
print(name)
raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}')
raise ValueError(
f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
)
else:
return getattr(lib, name)
class GlobalData(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.data = {}
@@ -190,15 +230,17 @@ class GlobalData(object):
return cls._instance
def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False):
#init_func = torch.empty
def get_transform_buffer(
shape, dtype, device, to_order, from_order="row", transpose=False
):
# init_func = torch.empty
init_func = torch.zeros
dims = len(shape)
if dims == 2:
rows = shape[0]
elif dims == 3:
rows = shape[0]*shape[1]
rows = shape[0] * shape[1]
cols = shape[-1]
state = (shape, to_order)
@@ -209,30 +251,39 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans
cols = tmp
state = (shape[::-1], to_order)
if to_order == 'row' or to_order == 'col':
if to_order == "row" or to_order == "col":
return init_func(shape, dtype=dtype, device=device), state
elif to_order == 'col32':
elif to_order == "col32":
# blocks of 32 columns (padded)
cols = 32*((cols+31)//32)
cols = 32 * ((cols + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
elif to_order == 'col_turing':
elif to_order == "col_turing":
# blocks of 32 columns and 8 rows
cols = 32*((cols+31)//32)
rows = 8*((rows+7)//8)
cols = 32 * ((cols + 31) // 32)
rows = 8 * ((rows + 7) // 8)
return init_func((rows, cols), dtype=dtype, device=device), state
elif to_order == 'col_ampere':
elif to_order == "col_ampere":
# blocks of 32 columns and 32 rows
cols = 32*((cols+31)//32)
rows = 32*((rows+31)//32)
cols = 32 * ((cols + 31) // 32)
rows = 32 * ((rows + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
else:
raise NotImplementedError(f'To_order not supported: {to_order}')
raise NotImplementedError(f"To_order not supported: {to_order}")
def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
if state is None: state = (A.shape, from_order)
else: from_order = state[1]
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
else: new_state = (state[1], to_order)
def nvidia_transform(
A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
):
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(
state[0], A.dtype, A.device, to_order, state[1]
)
else:
new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)
shape = state[0]
@@ -242,10 +293,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
elif ld is not None:
n = math.prod(shape)
dim1 = math.prod([shape[i] for i in ld])
dim2 = ct.c_int32(n//dim1)
dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1)
else:
dim1 = ct.c_int32(shape[0]*shape[1])
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptr = CUBLAS_Context.get_instance().get_context(A.device)
@@ -253,11 +304,13 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
ptrOut = get_ptr(out)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
return out, new_state
def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
'''
def estimate_quantiles(
A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
"""
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -282,18 +335,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
"""
if out is None:
out = torch.zeros((256,), dtype=torch.float32, device=A.device)
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
lib.cestimate_quantiles_fp32(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
lib.cestimate_quantiles_fp16(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
else:
raise NotImplementedError(f'Not supported data type {A.dtype}')
raise NotImplementedError(f"Not supported data type {A.dtype}")
return out
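An illustrative call of the quantile estimator above (assumes a CUDA tensor, since the kernel is dispatched through the compiled library):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024 * 1024, device="cuda")
code = F.estimate_quantiles(A)   # 256 equidistant quantiles of the empirical CDF, float32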
def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
'''
def quantize_blockwise(
A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
Quantizes tensor A by dividing it into blocks of 4096 values.
@@ -319,51 +380,96 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
The 8-bit tensor.
tuple(torch.Tensor, torch.Tensor):
The quantization state to undo the quantization.
'''
"""
if code is None:
if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
code = name2qmap['dynamic']
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if absmax is None:
n = A.numel()
num_blocks = 4096
blocks = n//num_blocks
blocks = n // num_blocks
blocks += 1 if n % num_blocks > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
if A.device.type != "cpu":
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
lib.cquantize_blockwise_stochastic_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
get_ptr(rand),
ct.c_int32(rand_offset),
ct.c_int(A.numel()),
)
elif A.dtype == torch.float16:
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
lib.cquantize_blockwise_stochastic_fp16(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
get_ptr(rand),
ct.c_int32(rand_offset),
ct.c_int(A.numel()),
)
else:
raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
lib.cquantize_blockwise_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(A.numel()),
)
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
lib.cquantize_blockwise_fp16(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(A.numel()),
)
else:
raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
# cpu
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(A.numel()),
)
return out, (absmax, code)
def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
absmax: Tensor=None, code: Tensor=None, out: Tensor=None,
blocksize: int=4096) -> Tensor:
'''
def dequantize_blockwise(
A: Tensor,
quant_state: Tuple[Tensor, Tensor] = None,
absmax: Tensor = None,
code: Tensor = None,
out: Tensor = None,
blocksize: int = 4096,
) -> Tensor:
"""
Dequantizes blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in
@@ -387,57 +493,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
-------
torch.Tensor:
Dequantized tensor (default: float32)
'''
"""
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
code = name2qmap['dynamic']
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None: quant_state = (absmax, code)
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None:
quant_state = (absmax, code)
if blocksize not in [2048, 4096]:
raise ValueError(f'The blocksize of {blocksize} is not supported. Supported values: [2048, 4096]')
raise ValueError(
f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096]"
)
if A.device.type != 'cpu':
if A.device.type != "cpu":
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp32(
get_ptr(quant_state[1]),
get_ptr(A),
get_ptr(quant_state[0]),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp16(
get_ptr(quant_state[1]),
get_ptr(A),
get_ptr(quant_state[0]),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
)
else:
raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel()))
lib.cdequantize_blockwise_cpu_fp32(
get_ptr(quant_state[1]),
get_ptr(A),
get_ptr(quant_state[0]),
get_ptr(out),
ct.c_int(A.numel()),
)
return out
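A minimal round trip through the blockwise API above (values are illustrative; the fp32 CPU path shown here works without the stochastic-rounding `rand` argument):

import torch
import bitsandbytes.functional as F

A = torch.randn(3 * 4096)                      # three blocks of 4096 fp32 values
qA, quant_state = F.quantize_blockwise(A)      # uint8 codes plus (absmax, code)
A_restored = F.dequantize_blockwise(qA, quant_state)
max_err = (A - A_restored).abs().max()         # small quantization error expected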
def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor:
def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
if code is None:
if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
code = name2qmap['dynamic']
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
absmax = torch.abs(A).max()
inp = A/absmax
inp = A / absmax
out = quantize_no_absmax(inp, code, out)
return out, (absmax, code)
def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor:
def dequantize(
A: Tensor,
quant_state: Tuple[Tensor, Tensor] = None,
absmax: Tensor = None,
code: Tensor = None,
out: Tensor = None,
) -> Tensor:
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
code = name2qmap['dynamic']
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if quant_state is None: quant_state = (absmax, code)
if quant_state is None:
quant_state = (absmax, code)
out = dequantize_no_absmax(A, quant_state[1], out)
return out*quant_state[0]
return out * quant_state[0]
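For comparison with the blockwise variant, an illustrative round trip through the global-absmax path above (a CUDA tensor is assumed, since these kernels have no CPU fallback here):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, device="cuda")
qA, quant_state = F.quantize(A)        # one absmax for the whole tensor
A_restored = F.dequantize(qA, quant_state)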
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
'''
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
"""
Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@@ -456,13 +599,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
-------
torch.Tensor:
Quantized 8-bit tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
"""
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
'''
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
"""
Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@@ -481,17 +626,31 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
-------
torch.Tensor:
32-bit output tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
"""
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor,
beta1: float, eps: float, step: int, lr: float,
state2: Tensor=None, beta2: float=0.0,
weight_decay: float=0.0, gnorm_scale: float=1.0,
unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None:
'''
def optimizer_update_32bit(
optimizer_name: str,
g: Tensor,
p: Tensor,
state1: Tensor,
beta1: float,
eps: float,
step: int,
lr: float,
state2: Tensor = None,
beta2: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Tensor = None,
max_unorm: float = 0.0,
skip_zeros=False,
) -> None:
"""
Performs an inplace optimizer update with one or two optimizer states.
Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
@@ -528,33 +687,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
The maximum update norm relative to the weight norm.
skip_zeros : bool
Whether to skip zero-valued gradients or not (default: False).
'''
"""
param_norm = 0.0
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
if optimizer_name not in str2optimizer32bit:
raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
raise NotImplementedError(
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
)
if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
str2optimizer32bit[optimizer_name][0](
get_ptr(g),
get_ptr(p),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)
elif g.dtype == torch.float16 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
str2optimizer32bit[optimizer_name][1](
get_ptr(g),
get_ptr(p),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)
else:
raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
beta1: float, beta2: float, eps: float,
step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor,
weight_decay: float=0.0, gnorm_scale: float=1.0,
unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
'''
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
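A hypothetical single step of the 32-bit momentum variant dispatched above (tensors and hyperparameters are made up; in practice the bnb.optim optimizer classes issue this call):

import torch
import bitsandbytes.functional as F

p = torch.randn(4096, device="cuda")     # parameters
g = torch.randn_like(p)                  # gradients
state1 = torch.zeros_like(p)             # momentum buffer (fp32 state)
F.optimizer_update_32bit("momentum", g, p, state1, beta1=0.9, eps=1e-8, step=1, lr=1e-3)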
def optimizer_update_8bit(
optimizer_name: str,
g: Tensor,
p: Tensor,
state1: Tensor,
state2: Tensor,
beta1: float,
beta2: float,
eps: float,
step: int,
lr: float,
qmap1: Tensor,
qmap2: Tensor,
max1: Tensor,
max2: Tensor,
new_max1: Tensor,
new_max2: Tensor,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Tensor = None,
max_unorm: float = 0.0,
) -> None:
"""
Performs an inplace Adam update.
Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
@@ -602,56 +812,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
The tensor for the update norm.
max_unorm : float
The maximum update norm relative to the weight norm.
'''
"""
param_norm = 0.0
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr),
get_ptr(qmap1), get_ptr(qmap2),
get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
str2optimizer8bit[optimizer_name][0](
get_ptr(p),
get_ptr(g),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
get_ptr(qmap1),
get_ptr(qmap2),
get_ptr(max1),
get_ptr(max2),
get_ptr(new_max1),
get_ptr(new_max2),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_int32(g.numel()),
)
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr),
get_ptr(qmap1), get_ptr(qmap2),
get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
str2optimizer8bit[optimizer_name][1](
get_ptr(p),
get_ptr(g),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
get_ptr(qmap1),
get_ptr(qmap2),
get_ptr(max1),
get_ptr(max2),
get_ptr(new_max1),
get_ptr(new_max2),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_int32(g.numel()),
)
else:
raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
beta1: float, beta2: float, eps: float,
step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0,
skip_zeros=False) -> None:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
def optimizer_update_8bit_blockwise(
optimizer_name: str,
g: Tensor,
p: Tensor,
state1: Tensor,
state2: Tensor,
beta1: float,
beta2: float,
eps: float,
step: int,
lr: float,
qmap1: Tensor,
qmap2: Tensor,
absmax1: Tensor,
absmax2: Tensor,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
str2optimizer8bit_blockwise[optimizer_name][0](
get_ptr(p),
get_ptr(g),
get_ptr(state1),
get_ptr(state2),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
get_ptr(qmap1),
get_ptr(qmap2),
get_ptr(absmax1),
get_ptr(absmax2),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
str2optimizer8bit_blockwise[optimizer_name][1](
get_ptr(p),
get_ptr(g),
get_ptr(state1),
get_ptr(state2),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
get_ptr(qmap1),
get_ptr(qmap2),
get_ptr(absmax1),
get_ptr(absmax2),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)
else:
raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5):
def percentile_clipping(
grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
"""Applies percentile clipping
grad: torch.Tensor
@@ -663,11 +952,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
"""
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
lib.cpercentile_clipping_g32(
get_ptr(grad),
get_ptr(gnorm_vec),
ct.c_int32(step),
ct.c_int32(grad.numel()),
)
elif grad.dtype == torch.float16:
lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
lib.cpercentile_clipping_g16(
get_ptr(grad),
get_ptr(gnorm_vec),
ct.c_int32(step),
ct.c_int32(grad.numel()),
)
else:
raise ValueError(f'Gradient type {grad.dtype} not supported!')
raise ValueError(f"Gradient type {grad.dtype} not supported!")
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec)
@@ -675,31 +974,44 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
gnorm_scale = 1.0
if current_gnorm > clip_value:
gnorm_scale = clip_value/current_gnorm
gnorm_scale = clip_value / current_gnorm
return current_gnorm, clip_value, gnorm_scale
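An illustrative use of the clipping helper above (shapes and steps are made up): a rolling buffer of 100 squared gradient norms is maintained, and the returned scale is applied to the gradient when the current norm exceeds the percentile-based clip value:

import torch
import bitsandbytes.functional as F

gnorm_vec = torch.zeros(100, device="cuda")   # rolling history of squared gradient norms
grad = torch.randn(4096, device="cuda")
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, gnorm_vec, step=0)
grad.mul_(gnorm_scale)                        # no-op unless clipping triggered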
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
def histogram_scatter_add_2d(
histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
assert index1.dtype == torch.int32
assert index2.dtype == torch.int32
assert histogram.device.type == 'cuda'
assert index1.device.type == 'cuda'
assert index2.device.type == 'cuda'
assert source.device.type == 'cuda'
assert histogram.device.type == "cuda"
assert index1.device.type == "cuda"
assert index2.device.type == "cuda"
assert source.device.type == "cuda"
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
lib.chistogram_scatter_add_2d(
get_ptr(histogram),
get_ptr(index1),
get_ptr(index2),
get_ptr(source),
maxdim1,
n,
)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized(): torch.cuda.init()
if not torch.cuda.is_initialized():
torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}')
raise TypeError(
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
)
sA = A.shape
sB = B.shape
@@ -709,64 +1021,101 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
correct = True
if len(sA) == 2 and len(sB) == 2:
if not tA and not tB and A.shape[1] != B.shape[0]: correct = False
elif tA and not tB and A.shape[0] != B.shape[0]: correct = False
elif tA and tB and A.shape[0] != B.shape[1]: correct = False
elif not tA and tB and A.shape[1] != B.shape[1]: correct = False
if not tA and not tB and A.shape[1] != B.shape[0]:
correct = False
elif tA and not tB and A.shape[0] != B.shape[0]:
correct = False
elif tA and tB and A.shape[0] != B.shape[1]:
correct = False
elif not tA and tB and A.shape[1] != B.shape[1]:
correct = False
elif len(sA) == 3 and len(sB) == 2:
if not tA and not tB and A.shape[2] != B.shape[0]: correct = False
elif tA and not tB and A.shape[1] != B.shape[0]: correct = False
elif tA and tB and A.shape[1] != B.shape[1]: correct = False
elif not tA and tB and A.shape[2] != B.shape[1]: correct = False
if not tA and not tB and A.shape[2] != B.shape[0]:
correct = False
elif tA and not tB and A.shape[1] != B.shape[0]:
correct = False
elif tA and tB and A.shape[1] != B.shape[1]:
correct = False
elif not tA and tB and A.shape[2] != B.shape[1]:
correct = False
elif len(sA) == 3 and len(sB) == 3:
if not tA and not tB and A.shape[2] != B.shape[1]: correct = False
elif tA and not tB and A.shape[1] != B.shape[1]: correct = False
elif tA and tB and A.shape[1] != B.shape[2]: correct = False
elif not tA and tB and A.shape[2] != B.shape[2]: correct = False
if not tA and not tB and A.shape[2] != B.shape[1]:
correct = False
elif tA and not tB and A.shape[1] != B.shape[1]:
correct = False
elif tA and tB and A.shape[1] != B.shape[2]:
correct = False
elif not tA and tB and A.shape[2] != B.shape[2]:
correct = False
if out is not None:
sout = out.shape
# special case common in backprop
if not correct and len(sA) == 3 and len(sB) == 3:
if (sout[0] == sA[2] and sout[1] == sB[2] and
sA[0] == sB[0] and sA[1] == sB[1]):
if (
sout[0] == sA[2]
and sout[1] == sB[2]
and sA[0] == sB[0]
and sA[1] == sB[1]
):
correct = True
else:
if len(sA) == 2 and len(sB) == 2:
if not tA and not tB: sout = (sA[0], sB[1])
elif tA and tB: sout = (sA[1], sB[0])
elif tA and not tB: sout = (sA[1], sB[1])
elif not tA and tB: sout = (sA[0], sB[0])
if not tA and not tB:
sout = (sA[0], sB[1])
elif tA and tB:
sout = (sA[1], sB[0])
elif tA and not tB:
sout = (sA[1], sB[1])
elif not tA and tB:
sout = (sA[0], sB[0])
elif len(sA) == 3 and len(sB) == 2:
if not tA and not tB: sout = (sA[0], sA[1], sB[1])
elif tA and tB: sout = (sA[0], sA[2], sB[0])
elif tA and not tB: sout = (sA[0], sA[2], sB[1])
elif not tA and tB: sout = (sA[0], sA[1], sB[0])
if not tA and not tB:
sout = (sA[0], sA[1], sB[1])
elif tA and tB:
sout = (sA[0], sA[2], sB[0])
elif tA and not tB:
sout = (sA[0], sA[2], sB[1])
elif not tA and tB:
sout = (sA[0], sA[1], sB[0])
elif len(sA) == 3 and len(sB) == 3:
if not tA and not tB: sout = (sA[0], sA[1], sB[2])
elif tA and tB: sout = (sA[0], sA[2], sB[1])
elif tA and not tB: sout = (sA[0], sA[2], sB[2])
elif not tA and tB: sout = (sA[0], sA[1], sB[1])
if not tA and not tB:
sout = (sA[0], sA[1], sB[2])
elif tA and tB:
sout = (sA[0], sA[2], sB[1])
elif tA and not tB:
sout = (sA[0], sA[2], sB[2])
elif not tA and tB:
sout = (sA[0], sA[1], sB[1])
if not correct:
raise ValueError(f'Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.')
raise ValueError(
f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
)
return sout
def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
def igemm(
A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
):
sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if out is None:
out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if len(A.shape) == 3 and len(B.shape) == 3:
if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
return batched_igemm(A, B, out)
sA = A.shape
sB = B.shape
if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0])
if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0])
if transposed_A and len(sA) == 2:
sA = (sA[1], sA[0])
elif transposed_A and len(sA) == 3:
sA = (sA[0], sA[2], sA[0])
if transposed_B and len(sB) == 2:
sB = (sB[1], sB[0])
elif transposed_B and len(sB) == 3:
sB = (sB[0], sB[2], sB[0])
# this is a mess: cuBLAS expect column major, but PyTorch is row major.
# So to perform the matrix multiplication, we have to treat A, B, and C matrices
# (transpose of row major is column major)
@@ -777,23 +1126,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
# column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
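# worked example of the layout trick above (illustrative shapes): for row-major
# A: [n, k] = [4, 8] and B: [k, m] = [8, 16], handing the unchanged buffers to
# cuBLAS as column-major operands and computing B^T @ A^T yields C^T in column
# major, which is exactly the row-major C: [n, m] = [4, 16] = A @ B,
# so no explicit transpose or copy is required.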
if len(sB) == 2:
if B.stride()[0] == B.shape[1]: transposed_B = False
elif B.stride()[1] == B.shape[0]: transposed_B = True
if B.stride()[0] == B.shape[1]:
transposed_B = False
elif B.stride()[1] == B.shape[0]:
transposed_B = True
if len(A.shape) == 2:
if A.stride()[0] == A.shape[1]: transposed_A = False
elif A.stride()[1] == A.shape[0]: transposed_A = True
if A.stride()[0] == A.shape[1]:
transposed_A = False
elif A.stride()[1] == A.shape[0]:
transposed_A = True
else:
if A.stride()[1] == A.shape[2]: transposed_A = False
elif A.stride()[2] == A.shape[1]: transposed_A = True
if A.stride()[1] == A.shape[2]:
transposed_A = False
elif A.stride()[2] == A.shape[1]:
transposed_A = True
if len(sA) == 2:
n = sA[0]
ldb = A.stride()[1 if transposed_A else 0]
elif len(sA) == 3 and len(sB) == 2:
n = sA[0]*sA[1]
n = sA[0] * sA[1]
ldb = sA[2]
m = sB[1]
k = sB[0]
lda = B.stride()[(1 if transposed_B else 0)]
@@ -802,34 +1156,52 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# special case
assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]):
raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}')
raise ValueError(
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
)
transposed_A = True
transposed_B = False
m = sB[2]
n = sA[2]
k = sB[0]*sB[1]
k = sB[0] * sB[1]
lda = m
ldb = sA[2]
ldc = m
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
# [km, nk -> mn]
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
lib.cigemm(
ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
)
return out
def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
def batched_igemm(
A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
):
if not len(A.shape) == 3 or not len(B.shape) == 3:
raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}')
raise ValueError(
f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
)
sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if out is None:
out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if B.is_contiguous():
lda = B.stride()[1]
......@@ -886,17 +1258,33 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ldc = m
strideA = B.shape[1]*B.shape[2]
strideB = A.shape[1]*A.shape[2]
strideC = A.shape[1]*B.shape[2]
strideA = B.shape[1] * B.shape[2]
strideB = A.shape[1] * A.shape[2]
strideC = A.shape[1] * B.shape[2]
ptr = CUBLAS_Context.get_instance().get_context(A.device)
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
lib.cbatched_igemm(
ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
ct.c_long(strideA),
ct.c_long(strideB),
ct.c_long(strideC),
ct.c_uint32(num_batch),
)
return out
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
......@@ -905,28 +1293,34 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0]*shapeA[1]
m = shapeA[0] * shapeA[1]
if dimsB == 2:
rows = n = shapeB[0]
elif dimsB == 3:
rows = n = shapeB[0]*shapeB[1]
rows = n = shapeB[0] * shapeB[1]
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
out, Sout = get_transform_buffer(
(shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
)
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
out, Sout = get_transform_buffer(
(shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
)
assert dimsB != 3, 'len(B.shape)==3 not supported'
assert A.device.type == 'cuda'
assert B.device.type == 'cuda'
assert dimsB != 3, "len(B.shape)==3 not supported"
assert A.device.type == "cuda"
assert B.device.type == "cuda"
assert A.dtype == torch.int8
assert B.dtype == torch.int8
assert out.dtype == dtype
assert SA[1] == 'col32'
assert SB[1] in ['col_turing', 'col_ampere']
assert Sout[1] == 'col32'
assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}'
assert SA[1] == "col32"
assert SB[1] in ["col_turing", "col_ampere"]
assert Sout[1] == "col32"
assert (
shapeA[-1] == shapeB[-1]
), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
formatB = SB[1]
prev_device = A.device
torch.cuda.set_device(A.device)
......@@ -937,53 +1331,76 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
ptrC = get_ptr(out)
k = shapeA[-1]
lda = ct.c_int32(m*32)
if formatB == 'col_turing':
lda = ct.c_int32(m * 32)
if formatB == "col_turing":
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows+7)//8)*8*32)
ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
else:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows+31)//32)*32*32)
ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
ldc = ct.c_int32(m*32)
ldc = ct.c_int32(m * 32)
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
has_error = 0
ptrRowScale = get_ptr(None)
if formatB == 'col_turing':
if formatB == "col_turing":
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
has_error = lib.cigemmlt_turing_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
else:
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == 'col_ampere':
has_error = lib.cigemmlt_turing_8(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
elif formatB == "col_ampere":
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
has_error = lib.cigemmlt_ampere_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
else:
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
has_error = lib.cigemmlt_ampere_8(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
if has_error == 1:
raise Exception('cublasLt ran into an error!')
raise Exception("cublasLt ran into an error!")
torch.cuda.set_device(prev_device)
return out, Sout
def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None):
def mm_dequant(
A,
quant_state,
row_stats,
col_stats,
out=None,
new_row_stats=None,
new_col_stats=None,
):
assert A.dtype == torch.int32
out_shape = quant_state[0]
if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2])
if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None:
new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
if new_col_stats is None:
new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
assert (
new_row_stats.shape[0] == row_stats.shape[0]
), f"{new_row_stats.shape} vs {row_stats.shape}"
assert (
new_col_stats.shape[0] == col_stats.shape[0]
), f"{new_col_stats.shape} vs {col_stats.shape}"
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
......@@ -994,27 +1411,47 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
lib.cdequant_mm_int32_fp16(
ptrA,
ptrRowStats,
ptrColStats,
ptrOut,
ptrNewRowStats,
ptrNewColStats,
numRows,
numCols,
)
return out
def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
def get_colrow_absmax(
A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
assert A.dtype == torch.float16
device = A.device
cols = A.shape[-1]
if len(A.shape) == 3:
rows = A.shape[0]*A.shape[1]
rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]
col_tiles = (cols+255)//256
tiled_rows = ((rows+15)//16)*16
if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device)
col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None:
row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
-50000.0
)
if col_stats is None:
col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
-50000.0
)
if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros(
((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
)
ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
......@@ -1024,16 +1461,17 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
lib.cget_col_row_stats(
ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
)
post_call(prev_device)
if threshold > 0.0:
nnz_block_ptr.cumsum_(0)
return row_stats, col_stats, nnz_block_ptr
class COOSparseTensor(object):
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
......@@ -1050,6 +1488,7 @@ class COOSparseTensor(object):
self.colidx = colidx
self.values = values
class CSRSparseTensor(object):
def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32
......@@ -1057,7 +1496,7 @@ class CSRSparseTensor(object):
assert values.dtype == torch.float16
assert values.numel() == nnz
assert colidx.numel() == nnz
assert rowptr.numel() == rows+1
assert rowptr.numel() == rows + 1
self.rows = rows
self.cols = cols
......@@ -1066,6 +1505,7 @@ class CSRSparseTensor(object):
self.colidx = colidx
self.values = values
class CSCSparseTensor(object):
def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32
......@@ -1073,7 +1513,7 @@ class CSCSparseTensor(object):
assert values.dtype == torch.float16
assert values.numel() == nnz
assert rowidx.numel() == nnz
assert colptr.numel() == cols+1
assert colptr.numel() == cols + 1
self.rows = rows
self.cols = cols
......@@ -1082,13 +1522,17 @@ class CSCSparseTensor(object):
self.rowidx = rowidx
self.values = values
def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1)
rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device)
rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0)
return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
return CSRSparseTensor(
cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
)
def coo2csc(cooA):
val, col2rowidx = torch.sort(cooA.colidx)
......@@ -1096,11 +1540,12 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device)
colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
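# ---------------------------------------------------------------------------
# Editor's note (hedged illustration, not part of this commit): coo2csr/coo2csc
# above build the CSR row pointer / CSC column pointer by counting entries per
# row/column and taking a cumulative sum. A stand-alone version of that idea,
# with an illustrative helper name and int32 indices as in COOSparseTensor:
import torch

def _rowptr_from_rowidx(rowidx: torch.Tensor, rows: int) -> torch.Tensor:
    # counts[r] = number of stored values in row r; rowptr is the prefix sum,
    # so rowptr[r]:rowptr[r + 1] indexes the values belonging to row r.
    counts = torch.bincount(rowidx.long(), minlength=rows)
    rowptr = torch.zeros(rows + 1, dtype=torch.int32)
    rowptr[1:] = counts.cumsum(0).to(torch.int32)
    return rowptr

if __name__ == "__main__":
    rowidx = torch.tensor([0, 0, 2, 2, 2], dtype=torch.int32)  # 5 nnz in 3 rows
    print(_rowptr_from_rowidx(rowidx, rows=3))  # tensor([0, 2, 2, 5], dtype=torch.int32)
# ---------------------------------------------------------------------------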
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
......@@ -1108,23 +1553,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
def double_quant(
A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
device = A.device
assert A.dtype == torch.half
assert device.type == 'cuda'
assert device.type == "cuda"
prev_device = pre_call(A.device)
cols = A.shape[-1]
if len(A.shape) == 3:
rows = A.shape[0]*A.shape[1]
rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]
if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
if out_row is None:
out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
coo_tensor = None
ptrA = get_ptr(A)
......@@ -1136,21 +1585,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
coo_tensor = coo_zeros(
A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
)
ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values)
ptrRowPtr = get_ptr(nnz_row_ptr)
lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
lib.cdouble_rowcol_quant(
ptrA,
ptrRowStats,
ptrColStats,
ptrOutCol,
ptrOutRow,
ptrRowIdx,
ptrColIdx,
ptrVal,
ptrRowPtr,
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
)
val, idx = torch.sort(coo_tensor.rowidx)
coo_tensor.rowidx = val
coo_tensor.colidx = coo_tensor.colidx[idx]
coo_tensor.values = coo_tensor.values[idx]
else:
lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols))
lib.cdouble_rowcol_quant(
ptrA,
ptrRowStats,
ptrColStats,
ptrOutCol,
ptrOutRow,
None,
None,
None,
None,
ct.c_float(0.0),
ct.c_int32(rows),
ct.c_int32(cols),
)
else:
lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
lib.cdouble_rowcol_quant(
ptrA,
ptrRowStats,
ptrColStats,
ptrOutCol,
ptrOutRow,
None,
None,
None,
None,
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
)
post_call(prev_device)
return out_row, out_col, row_stats, col_stats, coo_tensor
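# ---------------------------------------------------------------------------
# Editor's note (hedged illustration, not part of this commit): conceptually,
# double_quant returns the fp16 input quantized twice with absmax scales, once
# per row (out_row) and once per column (out_col); threshold > 0 additionally
# moves large "outlier" entries into a COO tensor. A pure-PyTorch reference of
# the dense part only, with an illustrative helper name:
import torch

def _reference_double_quant(A: torch.Tensor):
    row_scale = A.abs().amax(dim=1, keepdim=True)  # per-row absmax
    col_scale = A.abs().amax(dim=0, keepdim=True)  # per-column absmax
    out_row = torch.round(A * (127.0 / row_scale)).to(torch.int8)
    out_col = torch.round(A * (127.0 / col_scale)).to(torch.int8)
    return out_row, out_col, row_scale, col_scale

if __name__ == "__main__":
    A = torch.randn(8, 16)
    out_row, out_col, row_scale, col_scale = _reference_double_quant(A)
# ---------------------------------------------------------------------------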
......@@ -1159,69 +1649,81 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
print(
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
)
assert major >= 7
if major == 7: return 'col_turing'
elif major == 8: return 'col_ampere'
else: return 'col_turing'
if major == 7:
return "col_turing"
elif major == 8:
return "col_ampere"
else:
return "col_turing"
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
if state is None: state = (A.shape, from_order)
else: from_order = state[1]
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else: new_state = (state[0], to_order) # (shape, order)
def transform(
A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
):
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(
state[0], A.dtype, A.device, to_order, state[1], transpose
)
else:
new_state = (state[0], to_order) # (shape, order)
shape = state[0]
if len(shape) == 2:
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
else:
dim1 = ct.c_int32(shape[0]*shape[1])
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
if to_order == 'col32':
if to_order == "col32":
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == 'col_turing':
elif to_order == "col_turing":
if transpose:
lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == 'col_ampere':
elif to_order == "col_ampere":
if transpose:
lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == 'row':
if from_order == 'col_turing':
elif to_order == "row":
if from_order == "col_turing":
lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
elif from_order == 'col_ampere':
elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
raise NotImplementedError(
f"Transform function not implemented: From {from_order} to {to_order}"
)
return out, new_state
def spmm_coo(cooA, B, out=None):
if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
if out is None:
out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
assert cooA.cols == B.shape[0]
transposed_B = (False if B.is_contiguous() else True)
transposed_B = False if B.is_contiguous() else True
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
......@@ -1240,19 +1742,37 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
lib.cspmm_coo(
ptr,
ptrRowidx,
ptrColidx,
ptrValues,
cnnz,
crowsA,
ccolsA,
ccolsB,
cldb,
ptrB,
cldc,
ptrC,
ct.c_bool(transposed_B),
)
return out
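# ---------------------------------------------------------------------------
# Editor's note (hedged illustration, not part of this commit): spmm_coo
# multiplies a COO sparse matrix with a dense matrix via the compiled cspmm_coo
# routine. For small inputs, PyTorch's built-in sparse support gives the same
# product and is a convenient correctness reference:
import torch

rowidx = torch.tensor([0, 1, 1])
colidx = torch.tensor([2, 0, 2])
values = torch.tensor([1.0, 2.0, 3.0])
A_sparse = torch.sparse_coo_tensor(torch.stack([rowidx, colidx]), values, size=(2, 3))
B_dense = torch.randn(3, 4)
reference = torch.sparse.mm(A_sparse, B_dense)  # dense (2, 4) result
# ---------------------------------------------------------------------------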
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
if out is None:
out = torch.zeros(
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}'
assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
transposed_B = (False if B.is_contiguous() else True)
transposed_B = False if B.is_contiguous() else True
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
......@@ -1262,7 +1782,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
max_count, max_idx = torch.sort(counts, descending=True)
max_idx = max_idx.int()
max_count = max_count.int()
assert max_count[0] <= 32, f'Current max count per row is 32 but found {max_count[0]}.'
assert (
max_count[0] <= 32
), f"Current max count per row is 32 but found {max_count[0]}."
assert B.dtype in [torch.float16, torch.int8]
ptrOffset = get_ptr(offset)
ptrMaxCount = get_ptr(max_count)
......@@ -1282,134 +1804,183 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB = ct.c_int32(B.shape[1])
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
#print(cooA.rowidx[:64])
#print(cooA.colidx[:64].sort()[0])
# print(cooA.rowidx[:64])
# print(cooA.colidx[:64].sort()[0])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
lib.cspmm_coo_very_sparse_naive_fp16(
ptrMaxCount,
ptrMaxIdx,
ptrOffset,
ptrRowidx,
ptrColidx,
ptrValues,
ptrB,
ptrC,
ptrDequantStats,
cnnz_rows,
cnnz,
crowsA,
crowsB,
ccolsB,
)
elif B.dtype == torch.int8:
lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
#else: assertion error
lib.cspmm_coo_very_sparse_naive_int8(
ptrMaxCount,
ptrMaxIdx,
ptrOffset,
ptrRowidx,
ptrColidx,
ptrValues,
ptrB,
ptrC,
ptrDequantStats,
cnnz_rows,
cnnz,
crowsA,
crowsB,
ccolsB,
)
# else: assertion error
return out
C = 127.0
def vectorwise_quant(x, dim=1, quant_type='vector'):
if quant_type == 'linear':
def vectorwise_quant(x, dim=1, quant_type="vector"):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
xq = torch.round(x/max1*127).to(torch.int8)
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type in ['vector', 'row']:
elif quant_type in ["vector", "row"]:
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x*(C/max1)).to(torch.int8)
xq = torch.round(x * (C / max1)).to(torch.int8)
return xq, max1
elif quant_type == 'zeropoint':
elif quant_type == "zeropoint":
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0: dyna = 1
qx = 255./dyna
if dyna == 0:
dyna = 1
qx = 255.0 / dyna
minx = x.min()
zpx = torch.round(minx* qx)
x = torch.round(qx*x - zpx) + zpx
zpx = torch.round(minx * qx)
x = torch.round(qx * x - zpx) + zpx
return x, qx
elif quant_type in ['vector-zeropoint', 'row-zeropoint']:
elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True))
dyna[dyna==0] = 1
qx = 255./dyna
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
x, dim=dim, keepdim=True
)
dyna[dyna == 0] = 1
qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
zpx = torch.round(minx* qx)
x = torch.round(qx*x - zpx) + zpx
zpx = torch.round(minx * qx)
x = torch.round(qx * x - zpx) + zpx
return x, qx
elif quant_type == 'truncated-vector':
elif quant_type == "truncated-vector":
with torch.no_grad():
absx = torch.abs(x)
max1 = torch.amax(absx, dim=dim, keepdim=True)
max1 = max1*0.7
idx = (absx > max1.expand_as(absx))
max1 = max1 * 0.7
idx = absx > max1.expand_as(absx)
sign = torch.sign(x[idx])
x[idx] = max1.expand_as(absx)[idx]*sign
xq = torch.round(x/max1*C).to(torch.int8)
x[idx] = max1.expand_as(absx)[idx] * sign
xq = torch.round(x / max1 * C).to(torch.int8)
return xq, max1
else: return None
else:
return None
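# ---------------------------------------------------------------------------
# Editor's note (hedged illustration, not part of this commit): for
# quant_type="vector" the scheme above is symmetric absmax quantization along
# `dim`: xq = round(127 * x / max|x|), recovered as x ≈ xq * max|x| / 127, with a
# worst-case rounding error of about max|x| / 254 per row. A tiny stand-alone
# round-trip:
import torch

x = torch.randn(4, 8)
absmax = x.abs().amax(dim=-1, keepdim=True)            # per-row scale
xq = torch.round(x * (127.0 / absmax)).to(torch.int8)  # quantize
x_hat = xq.float() * absmax / 127.0                    # dequantize
print((x - x_hat).abs().max())
# ---------------------------------------------------------------------------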
def vectorwise_dequant(xq, max1, quant_type='vector'):
if quant_type == 'vector':
x = (xq/C*max1).to(torch.float32)
def vectorwise_dequant(xq, max1, quant_type="vector"):
if quant_type == "vector":
x = (xq / C * max1).to(torch.float32)
return x
else: return None
else:
return None
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'):
if quant_type == 'linear':
norm = S1*S2/(C*C)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
if quant_type == "linear":
norm = S1 * S2 / (C * C)
# double cast needed to prevent overflows
return (xq.float()*norm).to(dtype)
elif quant_type == 'zeropoint':
norm = 1.0/(S1*S2)
return (xq.float()*norm).to(dtype)
elif quant_type == 'row-zeropoint':
norm = 1.0/(S1*S2)
return (xq.float() * norm).to(dtype)
elif quant_type == "zeropoint":
norm = 1.0 / (S1 * S2)
return (xq.float() * norm).to(dtype)
elif quant_type == "row-zeropoint":
norm = 1.0 / (S1 * S2)
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= norm
else:
x *= norm
return x.to(dtype)
elif quant_type == 'vector-zeropoint':
elif quant_type == "vector-zeropoint":
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= 1.0/S1
x *= 1.0 / S1
else:
x *= 1.0/S1
x *= 1.0/S2.t()
x *= 1.0 / S1
x *= 1.0 / S2.t()
return x.to(dtype)
elif quant_type == 'row':
elif quant_type == "row":
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1*S2/(C*C)
x *= S1 * S2 / (C * C)
else:
x *= S1*S2/(C*C)
x *= S1 * S2 / (C * C)
return x.to(dtype)
elif quant_type in ['truncated-vector', 'vector']:
elif quant_type in ["truncated-vector", "vector"]:
x = xq.float()
if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
if len(S1.shape) == 3 and len(x.shape) == 2:
S1 = S1.squeeze(0)
if len(S2.shape) == 3 and len(x.shape) == 2:
S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= S1/C
x *= S1 / C
else:
x *= S1/C
x *= S2/C
x *= S1 / C
x *= S2 / C
return x.to(dtype)
else: return None
else:
return None
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
offset = B.float().t().sum(0)*(SA[0]+SA[1])
offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SB.shape) == 3:
SB = SB.squeeze(0)
if len(SB.shape) == 2:
x *= SB.t()/127
x *= SB.t() / 127
else:
x *= SB/127
x *= SA[1]/127
x +=offset
x *= SB / 127
x *= SA[1] / 127
x += offset
return x.to(dtype)
def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ['col_turing', 'col_ampere']
assert A.device.type == 'cuda'
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"
out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
......@@ -1420,13 +1991,9 @@ def extract_outliers(A, SA, idx):
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)
if formatA == 'col_turing':
if formatA == "col_turing":
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == 'col_ampere':
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
return out
......@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
......@@ -2,38 +2,58 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import bitsandbytes as bnb
from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
Tuple, TypeVar, Union, overload)
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
from torch import Tensor, device, dtype
from torch import nn
from torch.nn.parameter import Parameter
import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter
import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
T = TypeVar('T', bound='torch.nn.Module')
T = TypeVar("T", bound="torch.nn.Module")
class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
super(StableEmbedding, self).__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
_weight,
)
self.norm = torch.nn.LayerNorm(embedding_dim)
GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
)
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
""" !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the layer compatible with PyTorch < 1.9.
This means that if this changes in future PyTorch releases, this needs to change too,
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
'''
"""
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
......@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
input, self.weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
return self.norm(emb)
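# ---------------------------------------------------------------------------
# Editor's note (hedged usage sketch, not part of this commit): StableEmbedding
# behaves like torch.nn.Embedding plus a LayerNorm, Xavier initialization, and a
# forced 32-bit optimizer state for its weight (registered in __init__ above).
# A minimal forward pass, assuming bitsandbytes imports as bnb:
import torch
import bitsandbytes as bnb

emb = bnb.nn.StableEmbedding(num_embeddings=1000, embedding_dim=64)
tokens = torch.randint(0, 1000, (2, 16))
out = emb(tokens)  # shape (2, 16, 64), layer-normalized embeddings
# ---------------------------------------------------------------------------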
class Embedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
super(Embedding, self).__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
_weight,
)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
)
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
""" !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the layer compatible with PyTorch < 1.9.
This means that if this changes in future PyTorch releases, this needs to change too,
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
'''
"""
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
......@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
input, self.weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
return emb
class Int8Params(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
def __new__(
cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
......@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
del CBt
del SCBt
self.data = CB
setattr(self, 'CB', CB)
setattr(self, 'SCB', SCB)
setattr(self, "CB", CB)
setattr(self, "SCB", SCB)
return self
@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...) -> T:
def to(
self: T,
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...,
) -> T:
...
@overload
......@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
)
if (
device is not None
and device.type == "cuda"
and self.data.device.type == "cpu"
):
return self.cuda(device)
else:
new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
new_param = Int8Params(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights,
)
new_param.CB = self.CB
new_param.SCB = self.SCB
return new_param
class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
threshold=0.0,
index=None,
):
super(Linear8bitLt, self).__init__(input_features, output_features, bias)
self.state = bnb.MatmulLtState()
self.index=index
self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
......@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):
def forward(self, x):
self.state.is_training = self.training
if self.weight.CB is not None: self.init_8bit_state()
#assert not self.state.has_fp16_weights
#if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
if self.weight.CB is not None:
self.init_8bit_state()
# assert not self.state.has_fp16_weights
# if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out = bnb.matmul(x, self.weight, state=self.state)
......@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
return out
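# ---------------------------------------------------------------------------
# Editor's note (hedged usage sketch, not part of this commit, assumes a CUDA
# GPU and fp16 activations): Linear8bitLt is meant as a drop-in for nn.Linear
# that runs the matmul in int8; threshold > 0 enables the mixed-precision
# decomposition that keeps outlier features in fp16.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(768, 768, has_fp16_weights=False, threshold=6.0).cuda()
x = torch.randn(4, 768, dtype=torch.float16, device="cuda")
y = layer(x)  # int8 matmul internally, fp16 output
# ---------------------------------------------------------------------------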
class Linear8bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
def __init__(
self,
input_features,
output_features,
bias=True,
quant_type="vector",
index=None,
args=None,
sparse_decomp=False,
):
super(Linear8bit, self).__init__(input_features, output_features, bias)
self.quant_type = quant_type
self.index = index
......@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
self.iter += 1
if self.iter % self.args.clip_freq == 0:
with torch.no_grad():
maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
maxval, maxidx = torch.topk(
torch.abs(self.weight.flatten()), k=self.args.clip_idx
)
if not dist.is_initialized() or dist.get_rank() == 0:
print('clip', maxval[-1].item())
print("clip", maxval[-1].item())
self.weight.clip_(-maxval[-1], maxval[-1])
if self.args is not None:
out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
out = bnb.nn.functional.sparse_decomposed_linear8bit(
x,
self.weight,
self.bias,
qval=self.args.sparse_decomp_val,
quant_type=self.args.quant_type,
)
else:
out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
out = bnb.nn.functional.linear8bit(
x, self.weight, self.bias, quant_type=self.args.quant_type
)
return out
......@@ -4,9 +4,22 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class Adagrad(Optimizer1State):
def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-2,
lr_decay=0,
weight_decay=0,
initial_accumulator_value=0,
eps=1e-10,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
......@@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
raise ValueError('Initial accumulator value != 0.0 not supported!')
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError('Lr Decay != 0.0 not supported!')
super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad, self).__init__(
"adagrad",
params,
lr,
(0.0, 0.0),
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Adagrad8bit(Optimizer1State):
def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-2,
lr_decay=0,
weight_decay=0,
initial_accumulator_value=0,
eps=1e-10,
optim_bits=8,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
......@@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
raise ValueError('Initial accumulator value != 0.0 not supported!')
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError('Lr Decay != 0.0 not supported!')
raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
super(Adagrad8bit, self).__init__(
"adagrad",
params,
lr,
(0.0, 0.0),
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Adagrad32bit(Optimizer1State):
def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-2,
lr_decay=0,
weight_decay=0,
initial_accumulator_value=0,
eps=1e-10,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
......@@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
raise ValueError('Initial accumulator value != 0.0 not supported!')
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError('Lr Decay != 0.0 not supported!')
super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad32bit, self).__init__(
"adagrad",
params,
lr,
(0.0, 0.0),
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
......@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist
from bitsandbytes.optim.optimizer import Optimizer2State
import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super(Adam, self).__init__('adam', params, lr, betas, eps,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super(Adam, self).__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super(Adam8bit, self).__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super(Adam32bit, self).__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class AnalysisAdam(torch.optim.Optimizer):
......@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
eps=1e-8,
weight_decay=0,
amsgrad=False,
bnb_analysis='dynamic-blockwise',
savedir=None
bnb_analysis="dynamic-blockwise",
savedir=None,
):
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
......@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device)
state["abserrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
......@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
e = state['abserrors']
rele = state['relerrors']
counts = state['counts']
e = state["abserrors"]
rele = state["relerrors"]
counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
......@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg/denom
update_fp32 = exp_avg / denom
if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000:
if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
# embedding layer or too small
p_data_fp32 += -step_size*update_fp32
p_data_fp32 += -step_size * update_fp32
else:
if self.analysis == 'dynamic-blockwise':
if self.analysis == "dynamic-blockwise":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
state1 = F.dequantize_blockwise(C1, S1)
C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
state2 = F.dequantize_blockwise(C2, S2)
elif self.analysis == 'dynamic':
elif self.analysis == "dynamic":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
elif self.analysis == 'linear':
elif self.analysis == "linear":
code1 = F.create_linear_map(signed=True).to(p.device)
code2 = F.create_linear_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
elif self.analysis == 'quantile':
elif self.analysis == "quantile":
code1 = F.estimate_quantiles(exp_avg)
code2 = F.estimate_quantiles(exp_avg_sq)
C1 = F.quantize_no_absmax(exp_avg, code=code1)
state1 = F.dequantize_no_absmax(C1, code1)
C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
state2 = F.dequantize_no_absmax(C2, code2)
elif self.analysis == 'my-quantization-routine':
elif self.analysis == "my-quantization-routine":
pass
# 1. get code
# 2. quantize
# 3. dequantize
# Error will be calculated automatically!
else:
raise ValueError(f'Invalid analysis value: {self.analysis}!')
raise ValueError(f"Invalid analysis value: {self.analysis}!")
denom = state2.sqrt().add_(group["eps"])
update_8bit = state1/denom
update_8bit = state1 / denom
abserr = torch.abs(update_8bit-update_fp32)
relerr = abserr/torch.abs(update_fp32+1e-6)
abserr = torch.abs(update_8bit - update_fp32)
relerr = abserr / torch.abs(update_fp32 + 1e-6)
C1, C2 = C1.int(), C2.int()
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
p_data_fp32 += -step_size*update_fp32
F.histogram_scatter_add_2d(
counts, C1.int(), C2.int(), torch.ones_like(abserr)
)
p_data_fp32 += -step_size * update_fp32
if not dist.is_initialized() or dist.get_rank() == 0:
if self.savedir != '' and state['step'] % 100 == 0:
if not os.path.exists(self.savedir): os.makedirs(self.savedir)
shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir):
os.makedirs(self.savedir)
shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
pathe = os.path.join(
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
pathrele = os.path.join(
self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
)
pathcounts = os.path.join(
self.savedir, f"{p_id}_{shapestr}_counts.pkl"
)
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
......@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
return loss
......@@ -4,24 +4,90 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super(AdamW, self).__init__('adam', params, lr, betas, eps,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super(AdamW, self).__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super(AdamW8bit, self).__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
class AdamW32bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super(AdamW32bit, self).__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
......@@ -4,25 +4,102 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
class LAMB(Optimizer2State):
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
super(LAMB, self).__init__('lamb', params, lr, betas, eps,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
class LAMB8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
class LAMB32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
class LAMB8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB8bit, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
class LAMB32bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB32bit, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
......@@ -3,41 +3,119 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.optim import Optimizer
from bitsandbytes.optim.optimizer import Optimizer1State
class LARS(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0:
raise NotImplementedError(f'LARS without momentum is not supported!')
super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
raise NotImplementedError(f"LARS without momentum is not supported!")
super(LARS, self).__init__(
"lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class LARS8bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0:
raise NotImplementedError(f'LARS without momentum is not supported!')
super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
raise NotImplementedError(f"LARS without momentum is not supported!")
super(LARS8bit, self).__init__(
"lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class LARS32bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0:
raise NotImplementedError(f'LARS without momentum is not supported!')
super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
raise NotImplementedError(f"LARS without momentum is not supported!")
super(LARS32bit, self).__init__(
"lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class PytorchLARS(Optimizer):
def __init__(self, params, lr=0.01, momentum=0, dampening=0,
weight_decay=0, nesterov=False, max_unorm=0.02):
def __init__(
self,
params,
lr=0.01,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
max_unorm=0.02,
):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
......@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(PytorchLARS, self).__init__(params, defaults)
......@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
......@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
max_unorm = group['max_unorm']
lr = group['lr']
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
nesterov = group["nesterov"]
max_unorm = group["max_unorm"]
lr = group["lr"]
for p in group['params']:
if p.grad is None: continue
for p in group["params"]:
if p.grad is None:
continue
state = self.state[p]
d_p = p.grad
......@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0:
buf = state.get('momentum_buffer', None)
buf = state.get("momentum_buffer", None)
if buf is None:
buf = torch.clone(d_p).detach()
state['momentum_buffer']= buf
state["momentum_buffer"] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
update = d_p + buf*momentum
update = d_p + buf * momentum
else:
update = buf
......@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
assert p.dtype == torch.float32
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
if unorm > max_unorm*pnorm:
update_scale = max_unorm*pnorm/unorm
if unorm > max_unorm * pnorm:
update_scale = max_unorm * pnorm / unorm
p.add_(update, alpha=-lr*update_scale)
p.add_(update, alpha=-lr * update_scale)
return loss
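The trust-ratio clipping in step() above can be read in isolation: whenever the update norm exceeds max_unorm times the parameter norm, the update is rescaled so its norm sits exactly at that bound before the learning rate is applied. A standalone restatement of that arithmetic (pure tensor math, no optimizer state; the helper name is illustrative only):

import torch

def clip_update(p: torch.Tensor, update: torch.Tensor, lr: float, max_unorm: float = 0.02) -> torch.Tensor:
    # Mirrors the branch above: enforce ||update|| <= max_unorm * ||p||.
    pnorm = torch.norm(p)
    unorm = torch.norm(update)
    scale = 1.0
    if unorm > max_unorm * pnorm:
        scale = (max_unorm * pnorm / unorm).item()
    # Functional form of p.add_(update, alpha=-lr * scale)
    return p - lr * scale * update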
......@@ -2,12 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs
from collections import defaultdict
from copy import deepcopy
from itertools import chain
import torch
import bitsandbytes.functional as F
from copy import deepcopy
from itertools import chain
from collections import defaultdict, abc as container_abcs
class MockArgs(object):
def __init__(self, initial_data):
......@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
......@@ -38,15 +41,15 @@ class GlobalOptimManager(object):
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}]
param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group['params']):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
'''
"""
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overridden
......@@ -63,7 +66,7 @@ class GlobalOptimManager(object):
The value for the hyperparameters.
key_value_dict : dict
A dictionary with multiple key-values to override.
'''
"""
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
......@@ -75,16 +78,16 @@ class GlobalOptimManager(object):
if key_value_dict is not None:
for p in parameters:
if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
else: self.pid2config[id(p)] = key_value_dict
if id(p) in self.pid2config:
self.pid2config[id(p)].update(key_value_dict)
else:
self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
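As a usage sketch of the manager API above (the module and its emb attribute are hypothetical placeholders; this assumes the wrappers are exposed under bnb.optim, as in the bitsandbytes README), override_config lets selected parameters keep 32-bit optimizer state while the rest of the model runs with 8-bit state:

import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())            # `model` is a placeholder nn.Module
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Keep the (hypothetical) embedding weight in 32-bit optimizer state.
mng.override_config(model.emb.weight, "optim_bits", 32)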
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
......@@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
['qmap1', 'qmap2',
'max1', 'max2',
'new_max1', 'new_max2',
'state1', 'state2',
'gnorm_vec', 'absmax1', 'absmax2',
'unorm_vec'])
if optim_bits == 8: self.fill_qmap()
[
"qmap1",
"qmap2",
"max1",
"max2",
"new_max1",
"new_max2",
"state1",
"state2",
"gnorm_vec",
"absmax1",
"absmax2",
"unorm_vec",
]
)
if optim_bits == 8:
self.fill_qmap()
def fill_qmap(self):
self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
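For intuition about what fill_qmap stores, the dynamic quantization maps can be inspected with the same helper used above; the commented values are what one would expect, not output captured from a run:

import bitsandbytes.functional as F

qmap = F.create_dynamic_map(signed=True)     # 256-entry code used for "dynamic"
print(qmap.shape)                            # expected: torch.Size([256])
print(qmap.min().item(), qmap.max().item())  # expected: roughly -1.0 and 1.0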
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
......@@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
raise ValueError(
"loaded state dict has a different number of " "parameter groups"
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# Update the state
id_map = {old_id: p for old_id, p in
zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
)
}
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
......@@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
......@@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
new_group["params"] = group["params"]
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})
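A save/load round trip through the overridden load_state_dict above looks the same as for any PyTorch optimizer; the checkpoint path below is a placeholder:

import torch

torch.save(optimizer.state_dict(), "optim_ckpt.pt")     # placeholder path
# ... later, after rebuilding the model and an optimizer with matching param groups:
optimizer.load_state_dict(torch.load("optim_ckpt.pt"))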
def to_gpu(self):
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group['params']):
for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
......@@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
assert isinstance(pmodule, torch.Tensor) or isinstance(
pmodule, torch.nn.Parameter
)
found = False
for gindex, group in enumerate(self.param_groups):
if found: break
for pindex, p in enumerate(group['params']):
if found: break
if found:
break
for pindex, p in enumerate(group["params"]):
if found:
break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
id(p)
]
found = True
@torch.no_grad()
......@@ -223,7 +248,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.initialized = True
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group['params']):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
......@@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer):
def get_config(self, gindex, pindex, group):
config = {}
config['betas'] = group['betas']
config['eps'] = group['eps']
config['weight_decay'] = group['weight_decay']
config['lr'] = group['lr']
config['optim_bits'] = self.args.optim_bits
config['min_8bit_size'] = self.args.min_8bit_size
config['percentile_clipping'] = self.args.percentile_clipping
config['block_wise'] = self.args.block_wise
config['max_unorm'] = self.args.max_unorm
config['skip_zeros'] = self.args.skip_zeros
config["betas"] = group["betas"]
config["eps"] = group["eps"]
config["weight_decay"] = group["weight_decay"]
config["lr"] = group["lr"]
config["optim_bits"] = self.args.optim_bits
config["min_8bit_size"] = self.args.min_8bit_size
config["percentile_clipping"] = self.args.percentile_clipping
config["block_wise"] = self.args.block_wise
config["max_unorm"] = self.args.max_unorm
config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
def init_state(self, group, p, gindex, pindex):
raise NotImplementedError(f'init_state method needs to be overidden')
raise NotImplementedError(f"init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(f'The update_step method needs to be overidden')
raise NotImplementedError(f"The update_step method needs to be overridden")
class Optimizer2State(Optimizer8bit):
def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0.0, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
skip_zeros=False):
def __init__(
self,
optimizer_name,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
betas = betas.replace('(', '').replace(')', '').strip().split(',')
betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
args['optim_bits'] = optim_bits
args['percentile_clipping'] = 100
args['min_8bit_size'] = min_8bit_size
args['percentile_clipping'] = percentile_clipping
args['block_wise'] = block_wise
args['max_unorm'] = max_unorm
args['skip_zeros'] = skip_zeros
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
args["max_unorm"] = max_unorm
args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
......@@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config['optim_bits'] == 32:
if config["optim_bits"] == 32:
dtype = torch.float32
elif config['optim_bits'] == 8:
elif config["optim_bits"] == 8:
dtype = torch.uint8
else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config['min_8bit_size']: dtype = torch.float32
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state['step'] = 0
state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8:
if state['step'] == 0:
if 'dynamic' not in self.name2qmap: self.fill_qmap()
self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
state['qmap1'] = self.name2qmap['dynamic']
state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
state['qmap2'] = self.name2qmap['udynamic']
if config['block_wise']:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap2"] = self.name2qmap["udynamic"]
if config["block_wise"]:
n = p.numel()
blocks = n//2048
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else:
state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config['percentile_clipping'] < 100:
state['gnorm_vec'] = torch.zeros((100,), device=p.device)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config['max_unorm'] > 0.0:
state['unorm_vec'] = torch.zeros((1,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
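The block count in init_state is a ceiling division by the 2048-element block size used for block-wise quantization, i.e. one float32 absmax entry per 2048 parameters. A quick check of just that arithmetic (the helper name is illustrative):

def num_blocks(n: int, block_size: int = 2048) -> int:
    # Same result as: blocks = n // 2048; blocks += 1 if n % 2048 > 0 else 0
    return (n + block_size - 1) // block_size

assert num_blocks(2048) == 1
assert num_blocks(2049) == 2
assert num_blocks(10_000_000) == 4883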
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
......@@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
state['step'] += 1
step = state['step']
state["step"] += 1
step = state["step"]
if config['percentile_clipping'] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else:
gnorm_scale = 1.0
if state['state1'].dtype == torch.float:
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
config['weight_decay'], gnorm_scale=gnorm_scale,
unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
state["state2"],
config["betas"][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["max1"],
state["max2"],
state["new_max1"],
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
# swap maxes
state['max1'], state['new_max1'] = state['new_max1'], state['max1']
state['max2'], state['new_max2'] = state['new_max2'], state['max2']
elif state['state1'].dtype == torch.uint8 and config['block_wise']:
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["absmax1"],
state["absmax2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)
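The Adam-style wrappers elsewhere in the package route through Optimizer2State, so a minimal usage sketch exercises the 8-bit block-wise branch handled last above (model and hyperparameters are placeholders; this assumes bnb.optim.Adam8bit is available, as in the bitsandbytes README):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()           # placeholder; weight has >= min_8bit_size elements
opt = bnb.optim.Adam8bit(
    model.parameters(), lr=1e-3, betas=(0.9, 0.995),
    percentile_clipping=95,                          # allocates gnorm_vec, as in init_state above
)

model(torch.randn(8, 4096, device="cuda")).sum().backward()
opt.step()                                           # uint8 + block_wise path -> optimizer_update_8bit_blockwise
opt.zero_grad()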
class Optimizer1State(Optimizer8bit):
def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
weight_decay=0.0, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
skip_zeros=False):
def __init__(
self,
optimizer_name,
params,
lr=1e-3,
betas=(0.9, 0.0),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
......@@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
args['optim_bits'] = optim_bits
args['percentile_clipping'] = 100
args['min_8bit_size'] = min_8bit_size
args['percentile_clipping'] = percentile_clipping
args['block_wise'] = block_wise
args['max_unorm'] = max_unorm
args['skip_zeros'] = skip_zeros
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
args["max_unorm"] = max_unorm
args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
......@@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config['optim_bits'] == 32:
if config["optim_bits"] == 32:
dtype = torch.float32
elif config['optim_bits'] == 8:
elif config["optim_bits"] == 8:
dtype = torch.uint8
else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config['min_8bit_size']: dtype = torch.float32
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state['step'] = 0
state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8:
if state['step'] == 0:
if 'dynamic' not in self.name2qmap: self.fill_qmap()
self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
state['qmap1'] = self.name2qmap['dynamic']
if config['block_wise']:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
if config["block_wise"]:
n = p.numel()
blocks = n//2048
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else:
state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
if config['percentile_clipping'] < 100:
state['gnorm_vec'] = torch.zeros((100,), device=p.device)
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config['max_unorm'] > 0.0:
state['unorm_vec'] = torch.zeros((1,), device=p.device)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
......@@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
state['step'] += 1
step = state['step']
state["step"] += 1
step = state["step"]
if config['percentile_clipping'] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else:
gnorm_scale = 1.0
if state['state1'].dtype == torch.float:
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
None, 0.0, config['weight_decay'], gnorm_scale,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
skip_zeros=config['skip_zeros'])
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
config['weight_decay'], gnorm_scale,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
state['max1'], state['new_max1'] = state['new_max1'], state['max1']
elif state['state1'].dtype == torch.uint8 and config['block_wise']:
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], None, state['absmax1'], None,
config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
None,
0.0,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["max1"],
None,
state["new_max1"],
None,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["absmax1"],
None,
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)
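Both update_step implementations call F.percentile_clipping, which tracks a rolling window of 100 recent gradient norms and derives a clip value at the configured percentile; gnorm_scale then shrinks the effective gradient whenever the current norm exceeds that value. A condensed restatement of the idea in plain Python (not the CUDA-backed implementation; names here are illustrative):

import torch

def gnorm_scale(grad: torch.Tensor, gnorm_history: list, percentile: int = 95) -> float:
    # Keep the last 100 norms, mirroring the 100-slot gnorm_vec allocated in init_state.
    current = torch.norm(grad).item()
    gnorm_history.append(current)
    window = sorted(gnorm_history[-100:])
    clip = window[max(0, len(window) * percentile // 100 - 1)]
    # Scale < 1.0 only when the current norm sticks out above the chosen percentile.
    return 1.0 if current == 0 or current <= clip else clip / current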
......@@ -4,33 +4,106 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class RMSprop(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class RMSprop8bit(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class RMSprop32bit(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered:
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
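The RMSprop wrappers map (alpha, momentum) onto the generic betas pair of Optimizer1State; a usage sketch with a placeholder model and hyperparameters (assuming bnb.optim.RMSprop8bit is exported as defined above):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()                 # placeholder
opt = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-2,
                            alpha=0.99, momentum=0.0)      # becomes betas=(0.99, 0.0) internally

model(torch.randn(4, 4096, device="cuda")).sum().backward()
opt.step()
opt.zero_grad()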
......@@ -4,29 +4,96 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class SGD(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f'SGD without momentum is not supported!')
super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD, self).__init__(
"momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class SGD8bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f'SGD without momentum is not supported!')
super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD8bit, self).__init__(
"momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class SGD32bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f'SGD without momentum is not supported!')
super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD32bit, self).__init__(
"momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
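Switching from torch.optim.SGD to the 8-bit variant defined above is a one-line change, provided momentum is non-zero (momentum-free SGD raises NotImplementedError here); the model is again a placeholder:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()   # placeholder

# opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt = bnb.optim.SGD8bit(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(4, 4096, device="cuda")).sum().backward()
opt.step()
opt.zero_grad()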
import sys
def print_err(s: str) -> None:
print(s, file=sys.stderr)
def warn_of_missing_prerequisite(s: str) -> None:
print_err('WARNING, missing pre-requisite: ' + s)
print_err("WARNING, missing pre-requisite: " + s)
from itertools import product
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from itertools import product
def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
k = 25
for i in range(k):
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
torch.int8
)
elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
torch.int8
)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, 'col32')
B2, SB = F.transform(B, 'colx')
A2, SA = F.transform(A, "col32")
B2, SB = F.transform(B, "colx")
if dims == 2:
C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
C2, SC = F.transform(
torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
"col32",
)
else:
C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
C2, SC = F.transform(
torch.zeros(
A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
),
"col32",
)
F.igemmlt(A2, B2, C2, SA, SB, SC)
C3, S = F.transform(C2, 'row', state=SC)
#torch.testing.assert_allclose(C1, C3.float())
#print(C1)
#print(C2)
#print(C3)
C3, S = F.transform(C2, "row", state=SC)
# torch.testing.assert_allclose(C1, C3.float())
# print(C1)
# print(C2)
# print(C3)
allclose = torch.allclose(C1, C3.float())
if allclose:
print(C1)
......@@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
print(C3)
## transposed
#A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
#if dims == 2:
# A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
# if dims == 2:
# B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float().t())
#elif dims == 3:
# elif dims == 3:
# B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.float(), A.t().float())
# C1 = C1.permute([2, 0, 1])
#A2, SA = F.transform(A, 'col32')
#B2, SB = F.transform(B, 'colx')
#if dims == 2:
# A2, SA = F.transform(A, 'col32')
# B2, SB = F.transform(B, 'colx')
# if dims == 2:
# C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
#else:
# else:
# C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
# state = (C2.shape, 'row', A.shape[0])
# C2, SC = F.transform(C2, 'col32', state=state)
#F.igemmlt(A2, B2, C2, SA, SB, SC)
#C3, S = F.transform(C2, 'row', state=SC, ld=[0])
#torch.testing.assert_allclose(C1, C3.float())
# F.igemmlt(A2, B2, C2, SA, SB, SC)
# C3, S = F.transform(C2, 'row', state=SC, ld=[0])
# torch.testing.assert_allclose(C1, C3.float())
## weight update
#if dims == 3:
# if dims == 3:
# A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
......@@ -73,18 +87,18 @@ dims = (2, 3)
ldb = [0]
n = 2
dim1 = torch.randint(1,256, size=(n,)).tolist()
dim2 = torch.randint(32,512, size=(n,)).tolist()
dim3 = torch.randint(32,1024, size=(n,)).tolist()
dim4 = torch.randint(32,1024, size=(n,)).tolist()
values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
for ldb in range(32, 4096, 32):
#for ldb in [None]:
# for ldb in [None]:
val = test_igemmlt(2, 2, 2, 2, 2, ldb)
if val:
print(val, ldb)
else:
print('nope', ldb)
#for val in values:
#test_igemmlt(*val)
print("nope", ldb)
# for val in values:
# test_igemmlt(*val)
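Condensed from the test above, the essential int8 GEMM call sequence is: lay out the quantized inputs with F.transform, multiply with F.igemmlt into an int32 buffer, then transform back to row-major for comparison. The layout codes and the argument order are copied from the test itself, so this sketch assumes the same in-tree F.transform/F.igemmlt API; shapes are arbitrary:

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, size=(64, 256), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(128, 256), device="cuda").to(torch.int8)
C1 = torch.matmul(A.float(), B.t().float())             # fp32 reference

A2, SA = F.transform(A, "col32")
B2, SB = F.transform(B, "colx")
C2, SC = F.transform(
    torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"), "col32"
)
F.igemmlt(A2, B2, C2, SA, SB, SC)                       # int8 x int8 -> int32
C3, _ = F.transform(C2, "row", state=SC)
print(torch.allclose(C1, C3.float()))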
......@@ -2,18 +2,20 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
from setuptools import setup, find_packages
import os
from setuptools import find_packages, setup
libs = list(glob.glob('./bitsandbytes/libbitsandbytes*.so'))
libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
libs = [os.path.basename(p) for p in libs]
print('libs:', libs)
print("libs:", libs)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
setup(
name=f"bitsandbytes",
version=f"0.31.0",
......@@ -27,11 +29,11 @@ setup(
entry_points={
"console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
},
package_data={'': libs},
long_description=read('README.md'),
long_description_content_type='text/markdown',
package_data={"": libs},
long_description=read("README.md"),
long_description_content_type="text/markdown",
classifiers=[
"Development Status :: 4 - Beta",
'Topic :: Scientific/Engineering :: Artificial Intelligence'
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)