Unverified commit 3418cd39, authored by Tim Dettmers, committed by GitHub

Merge pull request #2 from TimDettmers/fix_imports

Remove unused imports, fix NotImplementedError
parents 4e60e7dc 33efe4a0
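
A note on why the exception rename below matters: Python has no builtin named NotImplementError, so the original guard clauses died with a NameError at the raise site instead of raising the intended NotImplementedError. A minimal sketch (not part of the diff) reproducing the bug:

# Evaluating the undefined name fails before any exception object is built.
try:
    raise NotImplementError('LARS without momentum is not supported!')
except NameError as err:
    print(err)  # name 'NotImplementError' is not defined
except NotImplementedError:
    print('unreachable while the typo is in place')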
@@ -2,13 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import ctypes as ct
 import os
 import random
 import math
+import ctypes as ct
+from typing import Tuple
 import torch
 from torch import Tensor
-from typing import Tuple
 lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
 name2qmap = {}
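
The import shuffle keeps `import ctypes as ct` because the module binds its native library at import time via `ct.cdll.LoadLibrary`. A minimal sketch of that loading pattern, assuming a hypothetical libexample.so placed next to the module:

import ctypes as ct
import os

# Resolve the shared library relative to this module rather than the
# current working directory; LoadLibrary raises OSError if it is missing.
_libpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'libexample.so')
lib = ct.cdll.LoadLibrary(_libpath)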
......
@@ -7,7 +7,6 @@ import torch
 from typing import Optional
 from torch import Tensor
 from torch.nn.parameter import Parameter
-import torch.nn.functional as F
 from bitsandbytes.optim import GlobalOptimManager
......
@@ -2,11 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 from bitsandbytes.optim.optimizer import Optimizer1State
-torch.optim.Adagrad
 class Adagrad(Optimizer1State):
     def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
             optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
......
@@ -2,9 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 from bitsandbytes.optim.optimizer import Optimizer2State
-import bitsandbytes.functional as F
 class AdamW(Optimizer2State):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
......
@@ -12,7 +12,7 @@ class LARS(Optimizer1State):
                  weight_decay=0, nesterov=False, optim_bits=32, args=None,
                  min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
         if momentum == 0:
-            raise NotImplementError(f'LARS without momentum is not supported!')
+            raise NotImplementedError(f'LARS without momentum is not supported!')
         super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
                 weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
@@ -21,7 +21,7 @@ class LARS8bit(Optimizer1State):
                  weight_decay=0, nesterov=False, args=None,
                  min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
         if momentum == 0:
-            raise NotImplementError(f'LARS without momentum is not supported!')
+            raise NotImplementedError(f'LARS without momentum is not supported!')
         super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
                 weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
@@ -30,7 +30,7 @@ class LARS32bit(Optimizer1State):
                  weight_decay=0, nesterov=False, args=None,
                  min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
         if momentum == 0:
-            raise NotImplementError(f'LARS without momentum is not supported!')
+            raise NotImplementedError(f'LARS without momentum is not supported!')
         super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
                 weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
......
@@ -2,16 +2,15 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 from bitsandbytes.optim.optimizer import Optimizer1State
 class RMSprop(Optimizer1State):
     def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if alpha == 0:
-            raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+            raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
         if centered:
-            raise NotImplementError(f'Centered RMSprop is not supported!')
+            raise NotImplementedError(f'Centered RMSprop is not supported!')
         super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
                 weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
@@ -19,9 +18,9 @@ class RMSprop8bit(Optimizer1State):
     def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if alpha == 0:
-            raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+            raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
         if centered:
-            raise NotImplementError(f'Centered RMSprop is not supported!')
+            raise NotImplementedError(f'Centered RMSprop is not supported!')
         super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
                 weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
@@ -30,7 +29,7 @@ class RMSprop32bit(Optimizer1State):
             min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if alpha == 0:
-            raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+            raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
         if centered:
             raise NotImplementError(f'Centered RMSprop is not supported!')
         super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
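
(Note that the centered-RMSprop guard in RMSprop32bit appears only as unchanged context in this last hunk, so its typo survives this commit.) The fixed guards are observable from user code; a hedged sketch, assuming bitsandbytes and torch are importable, of catching the rejection of centered RMSprop:

import torch
import bitsandbytes as bnb

try:
    # The guard fires in __init__, before any optimizer state is allocated.
    bnb.optim.RMSprop8bit(torch.nn.Linear(2, 2).parameters(), lr=1e-2, centered=True)
except NotImplementedError as err:
    print(err)  # Centered RMSprop is not supported!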
......
@@ -9,7 +9,7 @@ class SGD(Optimizer1State):
                  weight_decay=0, nesterov=False, optim_bits=32, args=None,
                  min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if momentum == 0:
-            raise NotImplementError(f'SGD without momentum is not supported!')
+            raise NotImplementedError(f'SGD without momentum is not supported!')
         super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
                 weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
@@ -18,7 +18,7 @@ class SGD8bit(Optimizer1State):
                  weight_decay=0, nesterov=False, args=None,
                  min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if momentum == 0:
-            raise NotImplementError(f'SGD without momentum is not supported!')
+            raise NotImplementedError(f'SGD without momentum is not supported!')
         super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
                 weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
@@ -27,6 +27,6 @@ class SGD32bit(Optimizer1State):
                  weight_decay=0, nesterov=False, args=None,
                  min_8bit_size=4096, percentile_clipping=100, block_wise=True):
         if momentum == 0:
-            raise NotImplementError(f'SGD without momentum is not supported!')
+            raise NotImplementedError(f'SGD without momentum is not supported!')
         super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
                 weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
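
These classes are drop-in replacements for their torch.optim counterparts. A minimal usage sketch based on the constructor signatures above (assumes a CUDA device, which the 8-bit optimizer states require):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
# momentum must be non-zero: momentum=0 now raises NotImplementedError
# instead of the old NameError.
optim = bnb.optim.SGD8bit(model.parameters(), lr=0.01, momentum=0.9)

out = model(torch.randn(8, 4096, device='cuda'))
out.sum().backward()
optim.step()
optim.zero_grad()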
@@ -6,10 +6,6 @@ import pytest
 import torch
 import bitsandbytes as bnb
-from itertools import product
-from bitsandbytes import functional as F
 @pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
 def test_embeddings(embcls):
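
The surviving decorator runs the test once per embedding class, labeled by the matching `ids` entry. The same pattern in isolation, with stand-in classes since the real ones need a GPU:

import pytest

@pytest.mark.parametrize("embcls", [dict, list], ids=['dict', 'list'])
def test_constructs(embcls):
    # One test case per class, reported as 'dict' and 'list'.
    assert embcls() is not None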
......
@@ -7,7 +7,6 @@ import time
 import shutil
 import uuid
 import pytest
-import ctypes
 import torch
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
......