Unverified commit f0ec93d0, authored by Tim Dettmers and committed by GitHub

Merge pull request #76 from tomaarsen/cleanup

Cleanup involving a handful of failures, some optimization and a lot of code quality improvements
parents c059bd28 c91f592a
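
The diff applies a consistent set of Python 3 modernizations across the code base. A condensed, runnable sketch of the recurring before/after idioms (illustrative only, not lifted verbatim from the changed files):

```python
# Recurring idioms in this cleanup, shown as small before/after examples.

# 1) Implicit object base class and zero-argument super()
class Base:                    # was: class Base(object):
    def __init__(self):
        self.ready = True

class Child(Base):
    def __init__(self):
        super().__init__()     # was: super(Child, self).__init__()

# 2) Set literal instead of set([...])
keys = {"qmap1", "qmap2"}      # was: set(["qmap1", "qmap2"])

# 3) f-strings instead of str.format(), and no f-prefix on static strings
lr = -1.0
msg = f"Invalid learning rate: {lr}"  # was: "Invalid learning rate: {}".format(lr)

assert Child().ready and "qmap1" in keys and msg.endswith("-1.0")
```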
@@ -12,13 +12,13 @@ import torch
 import bitsandbytes.functional as F
-class MockArgs(object):
+class MockArgs:
     def __init__(self, initial_data):
         for key in initial_data:
             setattr(self, key, initial_data[key])
-class GlobalOptimManager(object):
+class GlobalOptimManager:
     _instance = None
     def __init__(self):
@@ -56,9 +56,9 @@ class GlobalOptimManager(object):
         """
        Overrides initial optimizer config for specific parameters.
-        The key-values of the optimizer config for the input parameters are overidden
+        The key-values of the optimizer config for the input parameters are overridden
        This can be both, optimizer parameters like "betas", or "lr" or it can be
-        8-bit specific paramters like "optim_bits", "percentile_clipping".
+        8-bit specific parameters like "optim_bits", "percentile_clipping".
        Parameters
        ----------
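
As a usage note for this API: a hedged sketch of how `override_config` is typically combined with `register_parameters`, based on the docstring above. The model, layer choice, and hyperparameters are illustrative assumptions, not taken from this diff.

```python
import torch
import bitsandbytes as bnb

# Illustrative two-layer model; any torch.nn.Module's parameters work the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 10)).cuda()

mng = bnb.optim.GlobalOptimManager.get_instance()
# Register parameters before creating the optimizer so the override is seen
# when the optimizer state is initialized.
mng.register_parameters(model.parameters())
# Keep 32-bit optimizer state for the first layer's weight while the rest
# uses the optimizer's default precision.
mng.override_config(model[0].weight, "optim_bits", 32)

optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)
```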
@@ -93,13 +93,12 @@ class GlobalOptimManager(object):
 class Optimizer8bit(torch.optim.Optimizer):
     def __init__(self, params, defaults, optim_bits=32):
-        super(Optimizer8bit, self).__init__(params, defaults)
+        super().__init__(params, defaults)
         self.initialized = False
         self.name2qmap = {}
         self.mng = GlobalOptimManager.get_instance()
-        self.non_castable_tensor_keys = set(
-            [
+        self.non_castable_tensor_keys = {
             "qmap1",
             "qmap2",
             "max1",
@@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             "absmax1",
             "absmax2",
             "unorm_vec",
-            ]
-        )
+        }
         if optim_bits == 8:
             self.fill_qmap()
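
Side note on these keys: they guard the dtype cast applied when optimizer state is restored, so quantization metadata keeps its own dtype. A simplified sketch of that pattern, using an assumed `cast_state` helper rather than the library's actual `load_state_dict` logic:

```python
import torch

# Quantization metadata (quantization maps, absmax/max statistics, norm vectors)
# must not be cast to the parameter dtype when optimizer state is loaded.
NON_CASTABLE = {"qmap1", "qmap2", "max1", "max2", "absmax1", "absmax2", "unorm_vec"}

def cast_state(param: torch.Tensor, state: dict) -> dict:
    """Cast restored state tensors to the parameter's dtype, except quantization metadata."""
    return {
        key: value.to(param.dtype)
        if isinstance(value, torch.Tensor) and key not in NON_CASTABLE
        else value
        for key, value in state.items()
    }
```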
@@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
     def __setstate__(self, state):
-        super(Optimizer8bit, self).__setstate__(state)
+        super().__setstate__(state)
     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.
@@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer):
         id_map = {
             old_id: p
             for old_id, p in zip(
-                chain.from_iterable((g["params"] for g in saved_groups)),
-                chain.from_iterable((g["params"] for g in groups)),
+                chain.from_iterable(g["params"] for g in saved_groups),
+                chain.from_iterable(g["params"] for g in groups),
             )
         }
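
The two changed lines only drop a redundant pair of parentheses: a generator expression passed as the sole argument to a call needs no extra wrapping. A standalone illustration (not from the repository):

```python
from itertools import chain

groups = [{"params": [1, 2]}, {"params": [3]}]

# Both spellings build the same iterator; the inner parentheses are optional
# when the generator expression is the only argument.
with_parens = list(chain.from_iterable((g["params"] for g in groups)))
without_parens = list(chain.from_iterable(g["params"] for g in groups))
assert with_parens == without_parens == [1, 2, 3]
```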
@@ -284,11 +282,11 @@ class Optimizer8bit(torch.optim.Optimizer):
         return config
     def init_state(self, group, p, gindex, pindex):
-        raise NotImplementedError(f"init_state method needs to be overidden")
+        raise NotImplementedError("init_state method needs to be overridden")
     def update_step(self, group, p, gindex, pindex):
         raise NotImplementedError(
-            f"The update_step method needs to be overidden"
+            "The update_step method needs to be overridden"
         )
@@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit):
         skip_zeros=False,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         if isinstance(betas, str):
             # format: '(beta1, beta2)'
             betas = betas.replace("(", "").replace(")", "").strip().split(",")
@@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit):
                 )
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super(Optimizer2State, self).__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits)
         if args is None:
             args = {}
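
These message changes are purely cosmetic: `str.format` and an f-string render identical text. A minimal standalone check:

```python
lr = -0.1
old_style = "Invalid learning rate: {}".format(lr)
new_style = f"Invalid learning rate: {lr}"
assert old_style == new_style == "Invalid learning rate: -0.1"
```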
@@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit):
         skip_zeros=False,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
                 raise ValueError(
@@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit):
                 )
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super(Optimizer1State, self).__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits)
         if args is None:
             args = {}
......
@@ -23,11 +23,11 @@ class RMSprop(Optimizer1State):
     ):
         if alpha == 0:
             raise NotImplementedError(
-                f"RMSprop with alpha==0.0 is not supported!"
+                "RMSprop with alpha==0.0 is not supported!"
             )
         if centered:
-            raise NotImplementedError(f"Centered RMSprop is not supported!")
-        super(RMSprop, self).__init__(
+            raise NotImplementedError("Centered RMSprop is not supported!")
+        super().__init__(
             "rmsprop",
             params,
             lr,
@@ -59,11 +59,11 @@ class RMSprop8bit(Optimizer1State):
     ):
         if alpha == 0:
             raise NotImplementedError(
-                f"RMSprop with alpha==0.0 is not supported!"
+                "RMSprop with alpha==0.0 is not supported!"
             )
         if centered:
-            raise NotImplementedError(f"Centered RMSprop is not supported!")
-        super(RMSprop8bit, self).__init__(
+            raise NotImplementedError("Centered RMSprop is not supported!")
+        super().__init__(
             "rmsprop",
             params,
             lr,
@@ -96,11 +96,11 @@ class RMSprop32bit(Optimizer1State):
         if alpha == 0:
             raise NotImplementedError(
-                f"RMSprop with alpha==0.0 is not supported!"
+                "RMSprop with alpha==0.0 is not supported!"
             )
         if centered:
-            raise NotImplementedError(f"Centered RMSprop is not supported!")
-        super(RMSprop32bit, self).__init__(
+            raise NotImplementedError("Centered RMSprop is not supported!")
+        super().__init__(
             "rmsprop",
             params,
             lr,
......
@@ -21,8 +21,8 @@ class SGD(Optimizer1State):
         block_wise=True,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"SGD without momentum is not supported!")
-        super(SGD, self).__init__(
+            raise NotImplementedError("SGD without momentum is not supported!")
+        super().__init__(
             "momentum",
             params,
             lr,
@@ -52,8 +52,8 @@ class SGD8bit(Optimizer1State):
         block_wise=True,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"SGD without momentum is not supported!")
-        super(SGD8bit, self).__init__(
+            raise NotImplementedError("SGD without momentum is not supported!")
+        super().__init__(
             "momentum",
             params,
             lr,
@@ -83,8 +83,8 @@ class SGD32bit(Optimizer1State):
         block_wise=True,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"SGD without momentum is not supported!")
-        super(SGD32bit, self).__init__(
+            raise NotImplementedError("SGD without momentum is not supported!")
+        super().__init__(
             "momentum",
             params,
             lr,
......
@@ -121,5 +121,3 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
 template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 #endif
......
@@ -290,4 +290,3 @@ extern "C"
 void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
 void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
 }
......
@@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then
 else
     echo ""
 fi
......
@@ -26,9 +26,6 @@ setup(
     keywords="gpu optimizers optimization 8-bit quantization compression",
     url="https://github.com/TimDettmers/bitsandbytes",
     packages=find_packages(),
-    entry_points={
-        "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
-    },
     package_data={"": libs},
     long_description=read("README.md"),
     long_description_content_type="text/markdown",
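
The removed `entry_points` block is the standard setuptools hook for installing a console command; here it had exposed a `debug_cuda` command backed by `bitsandbytes.debug_cli:cli`. For reference, a generic sketch of that mechanism with hypothetical package and function names:

```python
from setuptools import find_packages, setup

setup(
    name="example-package",  # hypothetical package, not bitsandbytes
    packages=find_packages(),
    entry_points={
        # Installs an `example-tool` executable that invokes example_package.cli:main().
        "console_scripts": ["example-tool = example_package.cli:main"],
    },
)
```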
......
-from itertools import product, permutations
+from itertools import permutations, product
 import pytest
 import torch
@@ -27,7 +27,7 @@ str_values = list(
     )
 )
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
         *vals
     )
     for vals in str_values
@@ -286,7 +286,7 @@ str_values = list(
         has_bias
     )
 )
-names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
 @pytest.mark.parametrize(
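
Dropping the explicit indices is behavior-preserving: `str.format` auto-numbers empty `{}` fields in order, so the generated pytest ids stay identical. A small standalone illustration (not taken from the test file):

```python
import pytest

values = [(16, 32), (64, 128)]
# "{0}"-style and "{}"-style placeholders expand to the same id strings.
ids = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
assert ids == ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]

@pytest.mark.parametrize("dim1, dim2", values, ids=ids)
def test_shapes(dim1, dim2):
    assert dim1 < dim2
```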
......
 import os
-import pytest
-import bitsandbytes as bnb
 from typing import List, NamedTuple
+import pytest
+import bitsandbytes as bnb
 from bitsandbytes.cuda_setup import (
     CUDA_RUNTIME_LIB,
-    evaluate_cuda_setup,
     determine_cuda_runtime_lib_path,
+    evaluate_cuda_setup,
     extract_candidate_paths,
 )
......
@@ -28,7 +28,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
 class FFN(torch.nn.Module):
     def __init__(self, input_features, hidden_size, bias=True):
-        super(FFN, self).__init__()
+        super().__init__()
         self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
         self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
@@ -42,7 +42,7 @@ class FFN(torch.nn.Module):
         return x
-class Timer(object):
+class Timer:
     def __init__(self):
         self.starts = {}
         self.ends = {}
@@ -69,7 +69,7 @@ class Timer(object):
             self.ends.pop(name)
         if print_ms and name in self.agg:
-            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
+            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
         return self.agg[name]
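
Format specifications carry over unchanged when converting to f-strings: the `:.5f` spec after the field works the same either way. A quick standalone check:

```python
name, elapsed_ms = "matmul", 1234.5678
old_style = "{0} took: {1:.5f}s".format(name, elapsed_ms / 1000.0)
new_style = f"{name} took: {elapsed_ms / 1000.0:.5f}s"
assert old_style == new_style == "matmul took: 1.23457s"
```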
@@ -302,7 +302,7 @@ batched = [False, True]
 values = list(product(dim1, dim2, methods, batched))
 values_names = list(product(dim1, dim2, method_names, batched))
 names = [
-    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
     for vals in values_names
 ]
@@ -360,7 +360,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist()
 transpose = [(False, False), (False, True), (True, False), (True, True)]
 values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
 names = [
-    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
+    "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
     for vals in values
 ]
@@ -425,7 +425,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
 batch_dim = torch.randint(2, 16, size=(n,)).tolist()
 values = list(product(seq_dim, hidden_dim, batch_dim))
 names = [
-    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+    "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
 ]
@@ -457,7 +457,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist()
 transpose = [False, True]
 values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
 names = [
-    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
+    "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
     for vals in values
 ]
@@ -542,7 +542,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
 transpose = [(False, False), (True, False), (False, True), (True, True)]
 values = list(product(dim1, dim2, dim3, dim4, transpose))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
     for vals in values
 ]
@@ -580,7 +580,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist()
 dim2 = torch.randint(32, 128, size=(n,)).tolist()
 dim3 = torch.randint(32, 256, size=(n,)).tolist()
 values = list(product(dim1, dim2, dim3))
-names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
@@ -609,7 +609,7 @@ transpose = [False]
 dims = [2, 3]
 values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
-names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]
+names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
@@ -691,7 +691,7 @@ ldb = [0]
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
     for vals in values
 ]
@@ -739,7 +739,7 @@ dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dim3, dim4, dims))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
     for vals in values
 ]
@@ -797,7 +797,7 @@ values = [
 # values = list(product(batch, seq, model, hidden))
 names = [
-    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
 ]
@@ -965,7 +965,7 @@ dims = (2,)
 formatB = ["col_turing", "col_ampere"]
 has_bias = [True, False]
 values = list(product(dim1, dim4, dims, formatB, has_bias))
-names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
@@ -1015,7 +1015,7 @@ dim2 = [1 * 1024]
 dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dims))
-names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
@@ -1071,7 +1071,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1118,7 +1118,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1162,7 +1162,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1237,7 +1237,7 @@ inner = [12288 * 4, 4096 * 4]
 dim4 = [12288, 4096]
 values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1303,7 +1303,7 @@ values = list(
     product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
 )
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
+    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
         *vals
     )
     for vals in values
@@ -1354,7 +1354,7 @@ a_order = ["col_turing"]
 out_order = ["row"]
 values = list(product(dim1, dim2, dtype, a_order, out_order))
 names = [
-    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
+    "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
     for vals in values
 ]
@@ -1380,7 +1380,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 # dim2 = [5]
 values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1417,7 +1417,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
 # dim2 = [11]
 transposed_B = [False, True]
 values = list(product(dim1, dim2, transposed_B))
-names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
@@ -1498,7 +1498,7 @@ n = 2
 dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
 dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
 values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1563,7 +1563,7 @@ dtype = [torch.float16]
 out_function = ["zeros", "ones"]
 values = list(product(dim1, dim2, dtype, out_function))
 names = [
-    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
 ]
@@ -1680,7 +1680,7 @@ dim2 = [2048]
 # dim2 = [2]
 dtype = [torch.int8]
 values = list(product(dim1, dim2, dtype))
-names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
@@ -1796,7 +1796,7 @@ values.append((batch_size, seqdim, 768, 4 * 768))
 # values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = [
-    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
 ]
......
@@ -7,7 +7,7 @@ from torch import nn
 import bitsandbytes as bnb
-class MockArgs(object):
+class MockArgs:
     def __init__(self, initial_data):
         for key in initial_data:
             setattr(self, key, initial_data[key])
@@ -15,7 +15,7 @@ class MockArgs(object):
 class MLP8bit(torch.nn.Module):
     def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
-        super(MLP8bit, self).__init__()
+        super().__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
             dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
             threshold=threshold
@@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function):
 class Linear8bit(nn.Module):
     def __init__(self, input_features, output_features, bias=True, args=None):
-        super(Linear8bit, self).__init__()
+        super().__init__()
         self.input_features = input_features
         self.output_features = output_features
         self.args = args
@@ -312,7 +312,7 @@ class Linear8bit(nn.Module):
 threshold = [0.0, 3.0]
 values = threshold
-names = ["threshold_{0}".format(vals) for vals in values]
+names = [f"threshold_{vals}" for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
@@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient():
 threshold = [0.0, 2.0]
 values = threshold
-names = ["threshold_{0}".format(vals) for vals in values]
+names = [f"threshold_{vals}" for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
......
@@ -18,7 +18,7 @@ k = 20
 def get_temp_dir():
-    path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
+    path = f"/tmp/autoswap/{str(uuid.uuid4())}"
     os.makedirs(path, exist_ok=True)
     return path
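
Minor aside on the converted line: the explicit `str()` kept inside the f-string is redundant, since interpolation already falls back to `str()` for objects like `UUID` that define no custom `__format__`. A standalone check:

```python
import uuid

u = uuid.uuid4()
# Both spellings render identically; str() inside the f-string adds nothing.
assert f"/tmp/autoswap/{u}" == f"/tmp/autoswap/{str(u)}" == "/tmp/autoswap/" + str(u)
```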
@@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16]
 optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]
@@ -187,7 +187,7 @@ dim1 = [1024]
 dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 values = list(product(dim1, dim2, gtype))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
@@ -250,7 +250,7 @@ optimizer_names = [
 ]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]
@@ -391,7 +391,7 @@ gtype = [torch.float32]
 optim_bits = [32, 8]
 values = list(product(dim1, dim2, gtype, optim_bits))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+    "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
     for vals in values
 ]
@@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16]
 optimizer_names = ["adam8bit_blockwise"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]
......