Unverified Commit 3418cd39, authored by Tim Dettmers, committed by GitHub

Merge pull request #2 from TimDettmers/fix_imports

Remove unused imports, fix NotImplementedError
parents 4e60e7dc 33efe4a0
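The second half of this change replaces the misspelled exception name NotImplementError with the built-in NotImplementedError. Because NotImplementError is not defined anywhere, the old guard clauses would fail with a NameError the moment they fired, instead of raising the intended exception. A minimal sketch of the behaviour (not part of this commit; the check_momentum helper is hypothetical and stands in for the optimizer constructors touched below):

    def check_momentum(momentum):
        if momentum == 0:
            # Before the fix this line read `raise NotImplementError(...)`,
            # which fails with: NameError: name 'NotImplementError' is not defined
            raise NotImplementedError('SGD without momentum is not supported!')

    try:
        check_momentum(0)
    except NotImplementedError as e:
        print(e)  # -> SGD without momentum is not supported!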
@@ -2,13 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import ctypes as ct
 import os
 import random
-import math
-import ctypes as ct
+from typing import Tuple
 import torch
 from torch import Tensor
-from typing import Tuple
 lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
 name2qmap = {}
...
@@ -7,7 +7,6 @@ import torch
 from typing import Optional
 from torch import Tensor
-from torch.nn.parameter import Parameter
 import torch.nn.functional as F
 from bitsandbytes.optim import GlobalOptimManager
...
@@ -2,11 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 from bitsandbytes.optim.optimizer import Optimizer1State
-torch.optim.Adagrad
 class Adagrad(Optimizer1State):
     def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
             optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
...
@@ -2,9 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 from bitsandbytes.optim.optimizer import Optimizer2State
-import bitsandbytes.functional as F
 class AdamW(Optimizer2State):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
...
@@ -12,7 +12,7 @@ class LARS(Optimizer1State):
             weight_decay=0, nesterov=False, optim_bits=32, args=None,
             min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
         if momentum == 0:
-            raise NotImplementError(f'LARS without momentum is not supported!')
+            raise NotImplementedError(f'LARS without momentum is not supported!')
         super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
             weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
...
@@ -21,7 +21,7 @@ class LARS8bit(Optimizer1State):
             weight_decay=0, nesterov=False, args=None,
             min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
         if momentum == 0:
-            raise NotImplementError(f'LARS without momentum is not supported!')
+            raise NotImplementedError(f'LARS without momentum is not supported!')
         super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
             weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
...
@@ -30,7 +30,7 @@ class LARS32bit(Optimizer1State):
             weight_decay=0, nesterov=False, args=None,
             min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
         if momentum == 0:
-            raise NotImplementError(f'LARS without momentum is not supported!')
+            raise NotImplementedError(f'LARS without momentum is not supported!')
         super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
             weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
...
@@ -2,16 +2,15 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 from bitsandbytes.optim.optimizer import Optimizer1State
 class RMSprop(Optimizer1State):
     def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if alpha == 0:
-            raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+            raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
         if centered:
-            raise NotImplementError(f'Centered RMSprop is not supported!')
+            raise NotImplementedError(f'Centered RMSprop is not supported!')
         super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
             weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
...
@@ -19,9 +18,9 @@ class RMSprop8bit(Optimizer1State):
     def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if alpha == 0:
-            raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+            raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
         if centered:
-            raise NotImplementError(f'Centered RMSprop is not supported!')
+            raise NotImplementedError(f'Centered RMSprop is not supported!')
         super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
             weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
...
@@ -30,7 +29,7 @@ class RMSprop32bit(Optimizer1State):
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if alpha == 0:
-            raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+            raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
         if centered:
             raise NotImplementError(f'Centered RMSprop is not supported!')
         super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
...
@@ -9,7 +9,7 @@ class SGD(Optimizer1State):
             weight_decay=0, nesterov=False, optim_bits=32, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if momentum == 0:
-            raise NotImplementError(f'SGD without momentum is not supported!')
+            raise NotImplementedError(f'SGD without momentum is not supported!')
         super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
             weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
...
@@ -18,7 +18,7 @@ class SGD8bit(Optimizer1State):
             weight_decay=0, nesterov=False, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if momentum == 0:
-            raise NotImplementError(f'SGD without momentum is not supported!')
+            raise NotImplementedError(f'SGD without momentum is not supported!')
         super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
             weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
...
@@ -27,6 +27,6 @@ class SGD32bit(Optimizer1State):
             weight_decay=0, nesterov=False, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if momentum == 0:
-            raise NotImplementError(f'SGD without momentum is not supported!')
+            raise NotImplementedError(f'SGD without momentum is not supported!')
         super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
             weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
...
@@ -6,10 +6,6 @@ import pytest
 import torch
 import bitsandbytes as bnb
-from itertools import product
-from bitsandbytes import functional as F
 @pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
 def test_embeddings(embcls):
...
@@ -7,7 +7,6 @@ import time
 import shutil
 import uuid
 import pytest
-import ctypes
 import torch
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
...