"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a8f5134c113da402a93580ef7a021557e816c98d"
Unverified commit be5cecb8, authored by Tim Dettmers and committed by GitHub

Merge branch 'main' into main

parents 8724c990 f0ec93d0
@@ -149,3 +149,9 @@ Bug fixes:
 Bug fixes:
 - Fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected.
+
+### 0.35.4
+
+Bug fixes:
+- Fixed a bug where the CUDA Setup failed when the CUDA runtime was found but the CUDA library was not.
+- Fixed a bug where not finding the CUDA runtime led to an incomprehensible error.
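Both 0.35.4 fixes concern the CUDA setup path touched by the hunks below. As a quick sanity check (a sketch for context, not part of this commit), importing the package runs that setup and exposes whether the GPU binary was loaded:

import bitsandbytes as bnb  # triggers the CUDA setup logged by the cuda_setup code below
print(bnb.COMPILED_WITH_CUDA)  # False means the CPU-only library was loaded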
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from . import cuda_setup, utils
 from .autograd._functions import (
     MatmulLtState,
     bmm_cublas,
@@ -12,7 +13,6 @@ from .autograd._functions import (
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
-from . import cuda_setup, utils

 if COMPILED_WITH_CUDA:
     from .optim import adam
...
-# from bitsandbytes.debug_cli import cli
-# cli()
 import os
 import sys
 from warnings import warn
@@ -31,8 +28,8 @@ print()
 from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
-from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
 from .cuda_setup.env_vars import to_be_ignored
+from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle

 print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
 for k, v in os.environ.items():
...
 import operator
 import warnings
+from dataclasses import dataclass
+from functools import reduce  # Required in Python 3

 import torch

 import bitsandbytes.functional as F
-from dataclasses import dataclass
-from functools import reduce  # Required in Python 3

 # math.prod not compatible with python < 3.8
 def prod(iterable):
@@ -18,7 +19,7 @@ tensor = torch.Tensor
     This is particularly important for small models where outlier features
     are less systematic and occur with low frequency.
 """
-class GlobalOutlierPooler(object):
+class GlobalOutlierPooler:
     _instance = None

     def __init__(self):
@@ -49,8 +50,9 @@ class GlobalOutlierPooler(object):
 class MatMul8bit(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
+    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
+        if precision is None:
+            precision = [8, 8, 8]
         if precision[0] != 8:
             with torch.no_grad():
                 output = torch.matmul(A, B)
...
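The `precision` change in `MatMul8bit.forward` above is the usual fix for Python's mutable default argument pitfall: a default list is created once when the function is defined and is then shared by every call that does not pass the argument. A minimal illustration of the pattern (illustrative sketch only, not code from this repository):

def scale(values, factors=None):
    # build a fresh list per call instead of sharing one list object across calls
    if factors is None:
        factors = [8, 8, 8]
    return [v * f for v, f in zip(values, factors)]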
 import ctypes as ct
-import torch
 from pathlib import Path
 from warnings import warn

+import torch
+

-class CUDASetup(object):
+class CUDASetup:
     _instance = None

     def __init__(self):
@@ -52,8 +52,13 @@ class CUDASetup(object):
             self.add_log_entry('python setup.py install')

     def initialize(self):
-        self.cuda_setup_log = []
+        self.has_printed = False
         self.lib = None
+        self.run_cuda_setup()
+
+    def run_cuda_setup(self):
+        self.initialized = True
+        self.cuda_setup_log = []

         from .cuda_setup.main import evaluate_cuda_setup
         binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
@@ -89,7 +94,8 @@ class CUDASetup(object):
             else:
                 self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
                 self.lib = ct.cdll.LoadLibrary(binary_path)
-        except:
+        except Exception as ex:
+            self.add_log_entry(str(ex))
             self.print_log_stack()

     def add_log_entry(self, msg, is_warning=False):
@@ -116,7 +122,7 @@ try:
         CUDASetup.get_instance().generate_instructions()
         CUDASetup.get_instance().print_log_stack()
         raise RuntimeError('''
-        CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs to fix your environment!
+        CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
        If you cannot find any issues and suspect a bug, please open an issue with details about your environment:
        https://github.com/TimDettmers/bitsandbytes/issues''')
     lib.cadam32bit_g32
@@ -124,8 +130,6 @@ try:
     lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
-    warn(
-        "The installed version of bitsandbytes was compiled without GPU support. "
-        "8-bit optimizers and GPU quantization are unavailable."
-    )
+    warn("The installed version of bitsandbytes was compiled without GPU support. "
+        "8-bit optimizers and GPU quantization are unavailable.")
     COMPILED_WITH_CUDA = False
-from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
 from .main import evaluate_cuda_setup
+from .paths import (
+    CUDA_RUNTIME_LIB,
+    determine_cuda_runtime_lib_path,
+    extract_candidate_paths,
+)
@@ -19,9 +19,12 @@ evaluation:
 import ctypes
 import os

-from .paths import determine_cuda_runtime_lib_path
+import torch
+
 from bitsandbytes.cextension import CUDASetup

+from .paths import determine_cuda_runtime_lib_path
+

 def check_cuda_result(cuda, result_val):
     # 3. Check for CUDA errors
@@ -30,8 +33,11 @@ def check_cuda_result(cuda, result_val):
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
         CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")

+
+# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
 def get_cuda_version(cuda, cudart_path):
-    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
+    if cuda is None: return None
+
     try:
         cudart = ctypes.CDLL(cudart_path)
     except OSError:
@@ -45,7 +51,7 @@ def get_cuda_version(cuda, cudart_path):
     minor = (version-(major*1000))//10
     if major < 11:
-        CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
+        CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')

     return f'{major}{minor}'
@@ -73,7 +79,6 @@ def get_compute_capabilities(cuda):
     # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
     """
     nGpus = ctypes.c_int()
     cc_major = ctypes.c_int()
     cc_minor = ctypes.c_int()
@@ -100,44 +105,45 @@ def get_compute_capability(cuda):
     capabilities are downwards compatible. If no GPUs are detected, it returns
     None.
     """
-    ccs = get_compute_capabilities(cuda)
-    if ccs:
-        # TODO: handle different compute capabilities; for now, take the max
-        return ccs[-1]
-    return None
+    if cuda is None: return None
+
+    # TODO: handle different compute capabilities; for now, take the max
+    ccs = get_compute_capabilities(cuda)
+    if ccs: return ccs[-1]


 def evaluate_cuda_setup():
     if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
         print('')
-        print('=' * 35 + 'BUG REPORT' + '=' * 35)
+        print('='*35 + 'BUG REPORT' + '='*35)
         print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
         print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
-        print('To hide this message, set the BITSANDBYTES_NOWELCOME variable like so: export BITSANDBYTES_NOWELCOME=1')
-        print('=' * 80)
-
-    # if not torch.cuda.is_available():
-    #     print('No GPU detected. Loading CPU library...')
-    #     return binary_name
-
-    binary_name = "libbitsandbytes_cpu.so"
+        print('='*80)
+
+    if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None
+
     cuda_setup = CUDASetup.get_instance()
     cudart_path = determine_cuda_runtime_lib_path()
-    if cudart_path is None:
-        cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
-        return binary_name
-
-    cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
     cuda = get_cuda_lib_handle()
     cc = get_compute_capability(cuda)
-    cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
     cuda_version_string = get_cuda_version(cuda, cudart_path)
+
+    failure = False
+    if cudart_path is None:
+        failure = True
+        cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
+    else:
+        cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")

     if cc == '' or cc is None:
-        cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library...", is_warning=True)
-        return binary_name, cudart_path, cuda, cc, cuda_version_string
+        failure = True
+        cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
+    else:
+        cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
+
+    if cuda is None:
+        failure = True
+    else:
+        cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')

     # 7.5 is the minimum CC for cublaslt
     has_cublaslt = cc in ["7.5", "8.0", "8.6"]
@@ -148,16 +154,13 @@ def evaluate_cuda_setup():
     # we use ls -l instead of nvcc to determine the cuda version
     # since most installations will have the libcudart.so installed, but not the compiler
-    cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')

-    def get_binary_name():
-        "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
-        bin_base_name = "libbitsandbytes_cuda"
-        if has_cublaslt:
-            return f"{bin_base_name}{cuda_version_string}.so"
-        else:
-            return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
-
-    binary_name = get_binary_name()
+    if failure:
+        binary_name = "libbitsandbytes_cpu.so"
+    elif has_cublaslt:
+        binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
+    else:
+        "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
+        binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"

     return binary_name, cudart_path, cuda, cc, cuda_version_string
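The `failure` flag collapses binary selection into three outcomes. A standalone paraphrase of that decision, using a hypothetical helper name rather than anything defined in this diff:

def pick_binary_name(failure, has_cublaslt, cuda_version_string):
    # mirrors the selection at the end of evaluate_cuda_setup() above
    if failure:  # no libcudart, no GPU, or no CUDA driver handle
        return "libbitsandbytes_cpu.so"
    if has_cublaslt:  # compute capability 7.5, 8.0 or 8.6 can use cublasLt
        return f"libbitsandbytes_cuda{cuda_version_string}.so"
    return f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"

# pick_binary_name(False, True, "117") -> "libbitsandbytes_cuda117.so"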
 import errno
 from pathlib import Path
 from typing import Set, Union

 from bitsandbytes.cextension import CUDASetup

 from .env_vars import get_potentially_lib_path_containing_env_vars
...
-import typer
-
-cli = typer.Typer()
-
-
-@cli.callback()
-def callback():
-    """
-    Awesome Portal Gun
-    """
-
-
-@cli.command()
-def shoot():
-    """
-    Shoot the portal gun
-    """
-    typer.echo("Shooting portal gun")
-
-
-@cli.command()
-def load():
-    """
-    Load the portal gun
-    """
-    typer.echo("Loading portal gun")
@@ -3,15 +3,19 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import ctypes as ct
+import itertools
 import operator
 import random

 import torch
-import itertools
-import math
+from functools import reduce  # Required in Python 3
 from typing import Tuple

 from torch import Tensor

 from .cextension import COMPILED_WITH_CUDA, lib
-from functools import reduce  # Required in Python 3
+

 # math.prod not compatible with python < 3.8
 def prod(iterable):
@@ -82,7 +86,7 @@ if COMPILED_WITH_CUDA:
     )


-class CUBLAS_Context(object):
+class CUBLAS_Context:
     _instance = None

     def __init__(self):
@@ -112,7 +116,7 @@ class CUBLAS_Context(object):
         return self.context[device.index]


-class Cusparse_Context(object):
+class Cusparse_Context:
     _instance = None

     def __init__(self):
@@ -129,14 +133,73 @@ class Cusparse_Context(object):
         return cls._instance
-def create_linear_map(signed=True):
-    if signed:
-        return torch.linspace(-1.0, 1.0, 256)
-    else:
-        return torch.linspace(0.0, 1.0, 256)
+def create_linear_map(signed=True, total_bits=8, add_zero=True):
+    sign = (-1.0 if signed else 0.0)
+    total_values = 2**total_bits
+    if add_zero or total_bits < 8:
+        # add a zero
+        # since we simulate less bits by having zeros in the data type, we
+        # need to center the quantization around zero and as such lose
+        # a single value
+        total_values = (2**total_bits if not signed else 2**total_bits-1)
+
+    values = torch.linspace(sign, 1.0, total_values)
+    gap = 256 - values.numel()
+    if gap == 0:
+        return values
+    else:
+        l = values.numel()//2
+        #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
+        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
+
+
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
+    e = exponent_bits
+    p = precision_bits
+    has_sign = 1 if signed else 0
+    assert e+p == total_bits-has_sign
+    # the exponent is biased to 2^(e-1) - 1 == 0
+    evalues = []
+    pvalues = []
+    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
+        evalues.append(2**val)
+
+    values = []
+    lst = list(itertools.product([0, 1], repeat=precision_bits))
+    #for ev in evalues:
+    bias = 2**(exponent_bits-1)-1
+    for evalue in range(2**(exponent_bits)):
+        for bit_pattern in lst:
+            value = (1 if evalue != 0 else 0)
+            for i, pval in enumerate(list(bit_pattern)):
+                value += pval*(2**-(i+1))
+            if evalue == 0:
+                # subnormals
+                value = value*2**-(bias-1)
+            else:
+                # normals
+                value = value*2**-(evalue-bias-2)
+            values.append(value)
+            if signed:
+                values.append(-value)
+
+    assert len(values) == 2**total_bits
+    values.sort()
+    if total_bits < 8:
+        gap = 256 - len(values)
+        for i in range(gap):
+            values.append(0)
+    values.sort()
+    code = torch.Tensor(values)
+    code /= code.max()
+    return code


-def create_dynamic_map(signed=True, n=7):
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     """
     Creates the dynamic quantization map.
@@ -157,40 +220,57 @@ def create_dynamic_map(signed=True, n=7):
     # these are additional items that come from the case
     # where all the exponent bits are zero and no
     # indicator bit is present
-    additional_items = 2 ** (7 - n) - 1
+    non_sign_bits = total_bits - (1 if signed else 0)
+    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
     if not signed:
         additional_items = 2 * additional_items
-    for i in range(n):
-        fraction_items = (
-            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
-        )
+    for i in range(max_exponent_bits):
+        fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

         if additional_items > 0:
             boundaries = torch.linspace(0.1, 1, additional_items + 1)
             means = (boundaries[:-1] + boundaries[1:]) / 2.0
-            data += ((10 ** (-(n - 1) + i)) * means).tolist()
+            data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
             if signed:
-                data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+                data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

     data.append(0)
     data.append(1.0)
+
+    gap = 256 - len(data)
+    for i in range(gap):
+        data.append(0)
+
     data.sort()
     return Tensor(data)


+def create_quantile_map(A, total_bits=8):
+    q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
+    q = q.tolist()
+    q.append(0)
+
+    gap = 256 - len(q)
+    for i in range(gap):
+        q.append(0)
+
+    q.sort()
+
+    q = Tensor(q)
+    q = q/q.abs().max()
+    return q
+
+
 def get_special_format_str():
     if not torch.cuda.is_available(): return 'col_turing'
-    major, minor = torch.cuda.get_device_capability()
+    major, _minor = torch.cuda.get_device_capability()
     if major <= 7:
         return "col_turing"
-    elif major == 8:
+    if major == 8:
         return "col_ampere"
-    else:
-        return "col_turing"
+    return "col_turing"
@@ -318,16 +398,12 @@ def nvidia_transform(
         dim2 = ct.c_int32(shape[2])

     ptr = CUBLAS_Context.get_instance().get_context(A.device)
-    ptrA = get_ptr(A)
-    ptrOut = get_ptr(out)
     func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

     return out, new_state
-def estimate_quantiles(
-    A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -347,25 +423,37 @@ def estimate_quantiles(
     out : torch.Tensor
         Tensor with the 256 estimated quantiles.
     offset : float
-        The offset for the first and last quantile from 0 and 1. Default: 1/512
+        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+    num_quantiles : int
+        The number of equally spaced quantiles.

     Returns
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
     '''
+    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+    if num_quantiles < 256 and offset == 1/(512):
+        # override default arguments
+        offset = 1/(2*num_quantiles)
+
     if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
     is_on_gpu([A, out])
+    device = pre_call(A.device)
     if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     else:
         raise NotImplementedError(f"Not supported data type {A.dtype}")
+    post_call(device)
+
+    if num_quantiles < 256:
+        step = round(256/num_quantiles)
+        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+        out = out[idx]
+
     return out
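The new `num_quantiles` argument subsamples the 256 estimated quantiles down to a coarser code, which is what `create_quantile_map` above relies on. A small sketch of the intended call, assuming a CUDA tensor since this path goes through the GPU kernels:

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device="cuda")
q256 = F.estimate_quantiles(A)                   # 256 quantiles, offset defaults to 1/512
q15 = F.estimate_quantiles(A, num_quantiles=15)  # 15 quantiles, offset overridden to 1/(2*15)
assert q15.numel() == 15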
@@ -398,15 +486,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         The quantization state to undo the quantization.
     """

     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)

     if absmax is None:
         n = A.numel()
-        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
         blocks = n // blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
@@ -415,8 +502,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)

     if A.device.type != 'cpu':
-        is_on_gpu([code, A, absmax, out, rand])
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        cblocksize = ct.c_int32(blocksize)
+        prev_device = pre_call(A.device)
+        code = code.to(A.device)
         if rand is not None:
+            is_on_gpu([code, A, out, absmax, rand])
+            assert blocksize==4096
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
             if A.dtype == torch.float32:
@@ -424,20 +516,19 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
             elif A.dtype == torch.float16:
                 lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
             else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         else:
+            is_on_gpu([code, A, out, absmax])
             if A.dtype == torch.float32:
-                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
         # cpu
+        code = code.cpu()
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
@@ -482,27 +573,30 @@ def dequantize_blockwise(
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)

     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)

     if quant_state is None:
         quant_state = (absmax, code)
+    else:
+        absmax, code = quant_state

     if A.device.type != 'cpu':
-        if blocksize not in [2048, 4096]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
+        device = pre_call(A.device)
+        code = code.to(A.device)
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
-            raise ValueError(
-                f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-            )
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
+        code = code.cpu()
         lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

     return out
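The user-visible change in the two functions above is the wider range of supported block sizes (down to 64) plus the explicit device handling for `code`. A hedged round-trip sketch on a CUDA device:

import torch
import bitsandbytes.functional as F

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
q, state = F.quantize_blockwise(x, blocksize=256)                    # uint8 codes plus per-block absmax
x_hat = F.dequantize_blockwise(q, quant_state=state, blocksize=256)  # float32 reconstruction
print((x.float() - x_hat).abs().mean())                              # small blockwise quantization error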
@@ -958,7 +1052,7 @@ def histogram_scatter_add_2d(
     maxdim1 = ct.c_int32(histogram.shape[0])
     n = ct.c_int32(index1.numel())
-    is_on_gpu([histogram, index1, index2d, source])
+    is_on_gpu([histogram, index1, index2, source])
     lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)


 def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@@ -1417,7 +1511,7 @@ def get_colrow_absmax(
     return row_stats, col_stats, nnz_block_ptr


-class COOSparseTensor(object):
+class COOSparseTensor:
     def __init__(self, rows, cols, nnz, rowidx, colidx, values):
         assert rowidx.dtype == torch.int32
         assert colidx.dtype == torch.int32
@@ -1434,7 +1528,7 @@ class COOSparseTensor(object):
         self.values = values


-class CSRSparseTensor(object):
+class CSRSparseTensor:
     def __init__(self, rows, cols, nnz, rowptr, colidx, values):
         assert rowptr.dtype == torch.int32
         assert colidx.dtype == torch.int32
@@ -1451,7 +1545,7 @@ class CSRSparseTensor(object):
         self.values = values


-class CSCSparseTensor(object):
+class CSCSparseTensor:
     def __init__(self, rows, cols, nnz, colptr, rowidx, values):
         assert colptr.dtype == torch.int32
         assert rowidx.dtype == torch.int32
@@ -1615,8 +1709,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
     dim1 = ct.c_int32(shape[0] * shape[1])
     dim2 = ct.c_int32(shape[2])

-    ptrA = get_ptr(A)
-    ptrOut = get_ptr(out)
     is_on_gpu([A, out])
     if to_order == 'col32':
         if transpose:
...
@@ -2,24 +2,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterator,
-    Mapping,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-    Union,
-    overload,
-)
+from typing import Optional, TypeVar, Union, overload

 import torch
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
-from torch.nn.parameter import Parameter

 import bitsandbytes as bnb
 from bitsandbytes.optim import GlobalOptimManager
@@ -39,7 +26,7 @@ class StableEmbedding(torch.nn.Embedding):
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
     ) -> None:
-        super(StableEmbedding, self).__init__(
+        super().__init__(
             num_embeddings,
             embedding_dim,
             padding_idx,
@@ -96,7 +83,7 @@ class Embedding(torch.nn.Embedding):
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
     ) -> None:
-        super(Embedding, self).__init__(
+        super().__init__(
             num_embeddings,
             embedding_dim,
             padding_idx,
@@ -225,7 +212,7 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(
+        super().__init__(
             input_features, output_features, bias
         )
         self.state = bnb.MatmulLtState()
...
@@ -5,12 +5,11 @@
 from bitsandbytes.cextension import COMPILED_WITH_CUDA

+from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
 from .adam import Adam, Adam8bit, Adam32bit
 from .adamw import AdamW, AdamW8bit, AdamW32bit
-from .sgd import SGD, SGD8bit, SGD32bit
-from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .lamb import LAMB, LAMB8bit, LAMB32bit
-from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
-from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
+from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
+from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .sgd import SGD, SGD8bit, SGD32bit
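The reordering above is purely cosmetic; the exported optimizer classes are unchanged. For context, a typical (hedged) way they are used as drop-in replacements for their torch.optim counterparts:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)  # 8-bit drop-in for torch.optim.Adam
loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()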
@@ -21,18 +21,18 @@ class Adagrad(Optimizer1State):
         block_wise=True,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         if initial_accumulator_value != 0.0:
             raise ValueError("Initial accumulator value != 0.0 not supported!")
         if lr_decay != 0.0:
             raise ValueError("Lr Decay != 0.0 not supported!")
-        super(Adagrad, self).__init__(
+        super().__init__(
             "adagrad",
             params,
             lr,
@@ -63,19 +63,19 @@ class Adagrad8bit(Optimizer1State):
         block_wise=True,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         if initial_accumulator_value != 0.0:
             raise ValueError("Initial accumulator value != 0.0 not supported!")
         if lr_decay != 0.0:
             raise ValueError("Lr Decay != 0.0 not supported!")
         assert block_wise
-        super(Adagrad8bit, self).__init__(
+        super().__init__(
             "adagrad",
             params,
             lr,
@@ -106,18 +106,18 @@ class Adagrad32bit(Optimizer1State):
         block_wise=True,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         if initial_accumulator_value != 0.0:
             raise ValueError("Initial accumulator value != 0.0 not supported!")
         if lr_decay != 0.0:
             raise ValueError("Lr Decay != 0.0 not supported!")
-        super(Adagrad32bit, self).__init__(
+        super().__init__(
             "adagrad",
             params,
             lr,
...
@@ -28,7 +28,7 @@ class Adam(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
     ):
-        super(Adam, self).__init__(
+        super().__init__(
             "adam",
             params,
             lr,
@@ -57,7 +57,7 @@ class Adam8bit(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
     ):
-        super(Adam8bit, self).__init__(
+        super().__init__(
             "adam",
             params,
             lr,
@@ -86,7 +86,7 @@ class Adam32bit(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
     ):
-        super(Adam32bit, self).__init__(
+        super().__init__(
             "adam",
             params,
             lr,
@@ -146,7 +146,7 @@ class AnalysisAdam(torch.optim.Optimizer):
             weight_decay=weight_decay,
             amsgrad=amsgrad,
         )
-        super(AnalysisAdam, self).__init__(params, defaults)
+        super().__init__(params, defaults)
         self.analysis = bnb_analysis
         self.savedir = savedir
...
@@ -20,7 +20,7 @@ class AdamW(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
     ):
-        super(AdamW, self).__init__(
+        super().__init__(
             "adam",
             params,
             lr,
@@ -49,7 +49,7 @@ class AdamW8bit(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
     ):
-        super(AdamW8bit, self).__init__(
+        super().__init__(
             "adam",
             params,
             lr,
@@ -78,7 +78,7 @@ class AdamW32bit(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
     ):
-        super(AdamW32bit, self).__init__(
+        super().__init__(
             "adam",
             params,
             lr,
...
@@ -23,7 +23,7 @@ class LAMB(Optimizer2State):
         block_wise=False,
         max_unorm=1.0,
     ):
-        super(LAMB, self).__init__(
+        super().__init__(
             "lamb",
             params,
             lr,
@@ -56,7 +56,7 @@ class LAMB8bit(Optimizer2State):
         block_wise=False,
         max_unorm=1.0,
     ):
-        super(LAMB8bit, self).__init__(
+        super().__init__(
             "lamb",
             params,
             lr,
@@ -89,7 +89,7 @@ class LAMB32bit(Optimizer2State):
         block_wise=False,
         max_unorm=1.0,
     ):
-        super(LAMB32bit, self).__init__(
+        super().__init__(
             "lamb",
             params,
             lr,
...
@@ -25,9 +25,9 @@ class LARS(Optimizer1State):
     ):
         if momentum == 0:
             raise NotImplementedError(
-                f"LARS without momentum is not supported!"
+                "LARS without momentum is not supported!"
             )
-        super(LARS, self).__init__(
+        super().__init__(
             "lars",
             params,
             lr,
@@ -59,9 +59,9 @@ class LARS8bit(Optimizer1State):
     ):
         if momentum == 0:
             raise NotImplementedError(
-                f"LARS without momentum is not supported!"
+                "LARS without momentum is not supported!"
             )
-        super(LARS8bit, self).__init__(
+        super().__init__(
             "lars",
             params,
             lr,
@@ -93,9 +93,9 @@ class LARS32bit(Optimizer1State):
     ):
         if momentum == 0:
             raise NotImplementedError(
-                f"LARS without momentum is not supported!"
+                "LARS without momentum is not supported!"
             )
-        super(LARS32bit, self).__init__(
+        super().__init__(
             "lars",
             params,
             lr,
@@ -123,12 +123,12 @@ class PytorchLARS(Optimizer):
         max_unorm=0.02,
     ):
         if lr < 0.0:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if momentum < 0.0:
-            raise ValueError("Invalid momentum value: {}".format(momentum))
+            raise ValueError(f"Invalid momentum value: {momentum}")
         if weight_decay < 0.0:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )

         defaults = dict(
@@ -143,10 +143,10 @@ class PytorchLARS(Optimizer):
             raise ValueError(
                 "Nesterov momentum requires a momentum and zero dampening"
             )
-        super(PytorchLARS, self).__init__(params, defaults)
+        super().__init__(params, defaults)

     def __setstate__(self, state):
-        super(PytorchLARS, self).__setstate__(state)
+        super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault("nesterov", False)
@@ -181,7 +181,7 @@ class PytorchLARS(Optimizer):
                 state = self.state[p]
                 d_p = p.grad
                 if weight_decay != 0:
-                    d_p = d_p.add(param, alpha=weight_decay)
+                    d_p = d_p.add(p, alpha=weight_decay)

                 if momentum != 0:
                     buf = state.get("momentum_buffer", None)
...