Unverified commit f0ec93d0 authored by Tim Dettmers, committed by GitHub

Merge pull request #76 from tomaarsen/cleanup

Cleanup involving a handful of failures, some optimization and a lot of code quality improvements
parents c059bd28 c91f592a
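Most of this PR applies a few recurring Python 3 modernizations across the codebase. An illustrative before/after sketch of those patterns (a hypothetical class, not an excerpt from the diff):

```python
# Before: Python 2 era idioms
class Counter(object):
    def __init__(self):
        super(Counter, self).__init__()
        self.keys = set(["a", "b"])                 # set built from a list
        print("count: {0}".format(len(self.keys)))  # indexed .format()

# After: the Python 3 idioms this PR moves to
class Counter:
    def __init__(self):
        super().__init__()
        self.keys = {"a", "b"}                      # set literal
        print(f"count: {len(self.keys)}")           # f-string
```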
@@ -12,13 +12,13 @@ import torch
 import bitsandbytes.functional as F
-class MockArgs(object):
+class MockArgs:
     def __init__(self, initial_data):
         for key in initial_data:
             setattr(self, key, initial_data[key])
-class GlobalOptimManager(object):
+class GlobalOptimManager:
     _instance = None
     def __init__(self):
@@ -56,9 +56,9 @@ class GlobalOptimManager(object):
         """
         Overrides initial optimizer config for specific parameters.
-        The key-values of the optimizer config for the input parameters are overidden
+        The key-values of the optimizer config for the input parameters are overridden
         This can be both, optimizer parameters like "betas", or "lr" or it can be
-        8-bit specific paramters like "optim_bits", "percentile_clipping".
+        8-bit specific parameters like "optim_bits", "percentile_clipping".
         Parameters
         ----------
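The override described by this docstring is typically driven through the singleton before the optimizer is created. A minimal usage sketch, assuming the register-then-override call pattern (the sketch itself is not part of this diff):

```python
import torch
import bitsandbytes as bnb

# Keep one layer's optimizer state in 32-bit while the rest uses 8-bit.
model = torch.nn.Linear(128, 128)
mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())      # register before overriding
mng.override_config(model.weight, "optim_bits", 32)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
```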
@@ -93,13 +93,12 @@ class GlobalOptimManager(object):
 class Optimizer8bit(torch.optim.Optimizer):
     def __init__(self, params, defaults, optim_bits=32):
-        super(Optimizer8bit, self).__init__(params, defaults)
+        super().__init__(params, defaults)
         self.initialized = False
         self.name2qmap = {}
         self.mng = GlobalOptimManager.get_instance()
-        self.non_castable_tensor_keys = set(
-            [
+        self.non_castable_tensor_keys = {
             "qmap1",
             "qmap2",
             "max1",
@@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             "absmax1",
             "absmax2",
             "unorm_vec",
-            ]
-        )
+        }
         if optim_bits == 8:
             self.fill_qmap()
@@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer):
         self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
     def __setstate__(self, state):
-        super(Optimizer8bit, self).__setstate__(state)
+        super().__setstate__(state)
     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.
@@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer):
         id_map = {
             old_id: p
             for old_id, p in zip(
-                chain.from_iterable((g["params"] for g in saved_groups)),
-                chain.from_iterable((g["params"] for g in groups)),
+                chain.from_iterable(g["params"] for g in saved_groups),
+                chain.from_iterable(g["params"] for g in groups),
             )
         }
@@ -284,11 +282,11 @@ class Optimizer8bit(torch.optim.Optimizer):
         return config
     def init_state(self, group, p, gindex, pindex):
-        raise NotImplementedError(f"init_state method needs to be overidden")
+        raise NotImplementedError("init_state method needs to be overridden")
     def update_step(self, group, p, gindex, pindex):
         raise NotImplementedError(
-            f"The update_step method needs to be overidden"
+            "The update_step method needs to be overridden"
         )
@@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit):
         skip_zeros=False,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         if isinstance(betas, str):
             # format: '(beta1, beta2)'
             betas = betas.replace("(", "").replace(")", "").strip().split(",")
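The string form accepted here, e.g. betas="(0.9, 0.995)", parses as in this minimal sketch; the float() conversion is an assumption, since the hunk ends at the split:

```python
# Sketch of parsing betas passed as a string, per the comment above.
betas = "(0.9, 0.995)"
parts = betas.replace("(", "").replace(")", "").strip().split(",")
betas = tuple(float(p) for p in parts)     # assumed conversion step
assert all(0.0 <= b < 1.0 for b in betas)  # mirrors the validation below
```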
@@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit):
             )
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super(Optimizer2State, self).__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits)
         if args is None:
             args = {}
@@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit):
         skip_zeros=False,
     ):
         if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+            raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
+            raise ValueError(f"Invalid epsilon value: {eps}")
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
                 raise ValueError(
@@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit):
                 )
         if not 0.0 <= weight_decay:
             raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay)
+                f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super(Optimizer1State, self).__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits)
         if args is None:
             args = {}
......
@@ -23,11 +23,11 @@ class RMSprop(Optimizer1State):
     ):
         if alpha == 0:
             raise NotImplementedError(
-                f"RMSprop with alpha==0.0 is not supported!"
+                "RMSprop with alpha==0.0 is not supported!"
             )
         if centered:
-            raise NotImplementedError(f"Centered RMSprop is not supported!")
-        super(RMSprop, self).__init__(
+            raise NotImplementedError("Centered RMSprop is not supported!")
+        super().__init__(
             "rmsprop",
             params,
             lr,
@@ -59,11 +59,11 @@ class RMSprop8bit(Optimizer1State):
     ):
         if alpha == 0:
             raise NotImplementedError(
-                f"RMSprop with alpha==0.0 is not supported!"
+                "RMSprop with alpha==0.0 is not supported!"
             )
         if centered:
-            raise NotImplementedError(f"Centered RMSprop is not supported!")
-        super(RMSprop8bit, self).__init__(
+            raise NotImplementedError("Centered RMSprop is not supported!")
+        super().__init__(
             "rmsprop",
             params,
             lr,
@@ -96,11 +96,11 @@ class RMSprop32bit(Optimizer1State):
         if alpha == 0:
             raise NotImplementedError(
-                f"RMSprop with alpha==0.0 is not supported!"
+                "RMSprop with alpha==0.0 is not supported!"
             )
         if centered:
-            raise NotImplementedError(f"Centered RMSprop is not supported!")
-        super(RMSprop32bit, self).__init__(
+            raise NotImplementedError("Centered RMSprop is not supported!")
+        super().__init__(
             "rmsprop",
             params,
             lr,
......
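All three RMSprop variants above share the same guards: alpha == 0 and centered=True raise NotImplementedError. A hedged usage sketch (keyword names taken from these hunks; other defaults assumed):

```python
import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(64, 64, device="cuda"))]
opt = bnb.optim.RMSprop8bit(params, lr=1e-3, alpha=0.99, centered=False)

# Either of these would raise NotImplementedError per the checks above:
# bnb.optim.RMSprop8bit(params, lr=1e-3, alpha=0.0)
# bnb.optim.RMSprop8bit(params, lr=1e-3, centered=True)
```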
@@ -21,8 +21,8 @@ class SGD(Optimizer1State):
         block_wise=True,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"SGD without momentum is not supported!")
-        super(SGD, self).__init__(
+            raise NotImplementedError("SGD without momentum is not supported!")
+        super().__init__(
             "momentum",
             params,
             lr,
@@ -52,8 +52,8 @@ class SGD8bit(Optimizer1State):
         block_wise=True,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"SGD without momentum is not supported!")
-        super(SGD8bit, self).__init__(
+            raise NotImplementedError("SGD without momentum is not supported!")
+        super().__init__(
             "momentum",
             params,
             lr,
@@ -83,8 +83,8 @@ class SGD32bit(Optimizer1State):
         block_wise=True,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"SGD without momentum is not supported!")
-        super(SGD32bit, self).__init__(
+            raise NotImplementedError("SGD without momentum is not supported!")
+        super().__init__(
             "momentum",
             params,
             lr,
......
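Likewise, every SGD variant in this file maps onto the "momentum" optimizer and rejects momentum == 0, so momentum must be passed explicitly. A hedged sketch:

```python
import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(64, 64, device="cuda"))]
opt = bnb.optim.SGD8bit(params, lr=0.1, momentum=0.9)

# Assuming momentum defaults to 0, omitting it raises NotImplementedError:
# bnb.optim.SGD8bit(params, lr=0.1)
```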
@@ -121,5 +121,3 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
 template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 #endif
@@ -290,4 +290,3 @@ extern "C"
     void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
     void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
 }
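These C entry points back the blockwise CPU (de)quantization that bitsandbytes exposes in Python. A hedged round-trip sketch via bitsandbytes.functional; the exact return convention of quantize_blockwise (quantized tensor plus quantization state) is an assumption and has varied between versions:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(4096)
q, state = F.quantize_blockwise(A)           # uint8 codes + per-block state
A_restored = F.dequantize_blockwise(q, state)
print((A - A_restored).abs().max())          # expect a small round-trip error
```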
@@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then
 else
     echo ""
 fi
@@ -26,9 +26,6 @@ setup(
     keywords="gpu optimizers optimization 8-bit quantization compression",
     url="https://github.com/TimDettmers/bitsandbytes",
     packages=find_packages(),
-    entry_points={
-        "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
-    },
     package_data={"": libs},
     long_description=read("README.md"),
     long_description_content_type="text/markdown",
......
-from itertools import product, permutations
+from itertools import permutations, product
 import pytest
 import torch
@@ -27,7 +27,7 @@ str_values = list(
     )
 )
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
         *vals
     )
     for vals in str_values
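This names pattern, repeated throughout the test files below, renders one readable pytest id per parameter combination. A self-contained sketch of the idiom:

```python
from itertools import product

import pytest

dim1 = [16, 32]
dim2 = [64]
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]

@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_product_is_positive(dim1, dim2):
    assert dim1 * dim2 > 0
```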
@@ -286,7 +286,7 @@ str_values = list(
         has_bias
     )
 )
-names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
 @pytest.mark.parametrize(
......
 import os
-import pytest
-import bitsandbytes as bnb
 from typing import List, NamedTuple
+import pytest
+import bitsandbytes as bnb
 from bitsandbytes.cuda_setup import (
     CUDA_RUNTIME_LIB,
-    evaluate_cuda_setup,
     determine_cuda_runtime_lib_path,
+    evaluate_cuda_setup,
+    extract_candidate_paths,
 )
......
@@ -28,7 +28,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
 class FFN(torch.nn.Module):
     def __init__(self, input_features, hidden_size, bias=True):
-        super(FFN, self).__init__()
+        super().__init__()
         self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
         self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
@@ -42,7 +42,7 @@ class FFN(torch.nn.Module):
         return x
-class Timer(object):
+class Timer:
     def __init__(self):
         self.starts = {}
         self.ends = {}
@@ -69,7 +69,7 @@ class Timer(object):
             self.ends.pop(name)
         if print_ms and name in self.agg:
-            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
+            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
         return self.agg[name]
@@ -302,7 +302,7 @@ batched = [False, True]
 values = list(product(dim1, dim2, methods, batched))
 values_names = list(product(dim1, dim2, method_names, batched))
 names = [
-    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
     for vals in values_names
 ]
@@ -360,7 +360,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist()
 transpose = [(False, False), (False, True), (True, False), (True, True)]
 values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
 names = [
-    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
+    "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
     for vals in values
 ]
@@ -425,7 +425,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
 batch_dim = torch.randint(2, 16, size=(n,)).tolist()
 values = list(product(seq_dim, hidden_dim, batch_dim))
 names = [
-    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+    "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
 ]
@@ -457,7 +457,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist()
 transpose = [False, True]
 values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
 names = [
-    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
+    "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
     for vals in values
 ]
@@ -542,7 +542,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
 transpose = [(False, False), (True, False), (False, True), (True, True)]
 values = list(product(dim1, dim2, dim3, dim4, transpose))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
     for vals in values
 ]
@@ -580,7 +580,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist()
 dim2 = torch.randint(32, 128, size=(n,)).tolist()
 dim3 = torch.randint(32, 256, size=(n,)).tolist()
 values = list(product(dim1, dim2, dim3))
-names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
@@ -609,7 +609,7 @@ transpose = [False]
 dims = [2, 3]
 values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
-names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]
+names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
@@ -691,7 +691,7 @@ ldb = [0]
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
     for vals in values
 ]
@@ -739,7 +739,7 @@ dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dim3, dim4, dims))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
     for vals in values
 ]
@@ -797,7 +797,7 @@ values = [
 # values = list(product(batch, seq, model, hidden))
 names = [
-    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
 ]
@@ -965,7 +965,7 @@ dims = (2,)
 formatB = ["col_turing", "col_ampere"]
 has_bias = [True, False]
 values = list(product(dim1, dim4, dims, formatB, has_bias))
-names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
@@ -1015,7 +1015,7 @@ dim2 = [1 * 1024]
 dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dims))
-names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
@@ -1071,7 +1071,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1118,7 +1118,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1162,7 +1162,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1237,7 +1237,7 @@ inner = [12288 * 4, 4096 * 4]
 dim4 = [12288, 4096]
 values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1303,7 +1303,7 @@ values = list(
     product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
 )
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
+    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
         *vals
     )
     for vals in values
@@ -1354,7 +1354,7 @@ a_order = ["col_turing"]
 out_order = ["row"]
 values = list(product(dim1, dim2, dtype, a_order, out_order))
 names = [
-    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
+    "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
     for vals in values
 ]
@@ -1380,7 +1380,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
 # dim2 = [5]
 values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1417,7 +1417,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
 # dim2 = [11]
 transposed_B = [False, True]
 values = list(product(dim1, dim2, transposed_B))
-names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
@@ -1498,7 +1498,7 @@ n = 2
 dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
 dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
 values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1563,7 +1563,7 @@ dtype = [torch.float16]
 out_function = ["zeros", "ones"]
 values = list(product(dim1, dim2, dtype, out_function))
 names = [
-    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
 ]
@@ -1680,7 +1680,7 @@ dim2 = [2048]
 # dim2 = [2]
 dtype = [torch.int8]
 values = list(product(dim1, dim2, dtype))
-names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
@@ -1796,7 +1796,7 @@ values.append((batch_size, seqdim, 768, 4 * 768))
 # values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = [
-    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
 ]
......
@@ -7,7 +7,7 @@ from torch import nn
 import bitsandbytes as bnb
-class MockArgs(object):
+class MockArgs:
     def __init__(self, initial_data):
         for key in initial_data:
             setattr(self, key, initial_data[key])
@@ -15,7 +15,7 @@ class MockArgs(object):
 class MLP8bit(torch.nn.Module):
     def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
-        super(MLP8bit, self).__init__()
+        super().__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
             dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
             threshold=threshold
@@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function):
 class Linear8bit(nn.Module):
     def __init__(self, input_features, output_features, bias=True, args=None):
-        super(Linear8bit, self).__init__()
+        super().__init__()
         self.input_features = input_features
         self.output_features = output_features
         self.args = args
@@ -312,7 +312,7 @@ class Linear8bit(nn.Module):
 threshold = [0.0, 3.0]
 values = threshold
-names = ["threshold_{0}".format(vals) for vals in values]
+names = [f"threshold_{vals}" for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
@@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient():
 threshold = [0.0, 2.0]
 values = threshold
-names = ["threshold_{0}".format(vals) for vals in values]
+names = [f"threshold_{vals}" for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
......
@@ -18,7 +18,7 @@ k = 20
 def get_temp_dir():
-    path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
+    path = f"/tmp/autoswap/{str(uuid.uuid4())}"
     os.makedirs(path, exist_ok=True)
     return path
@@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16]
 optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]
@@ -187,7 +187,7 @@ dim1 = [1024]
 dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 values = list(product(dim1, dim2, gtype))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
@@ -250,7 +250,7 @@ optimizer_names = [
 ]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]
@@ -391,7 +391,7 @@ gtype = [torch.float32]
 optim_bits = [32, 8]
 values = list(product(dim1, dim2, gtype, optim_bits))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+    "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
     for vals in values
 ]
@@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16]
 optimizer_names = ["adam8bit_blockwise"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
-    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]
......