Unverified Commit 8b6fe9ee authored by Matthew Douglas, committed by GitHub

Test cleanup (#1576)

* Testing cleanup

* More test cleanup

* Additional deprecations/removals.

* Skip benchmark, deprecated, slow tests by default
parent 677ff400
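Note: together with the pyproject.toml change further down, tests marked slow, benchmark, or deprecated are now deselected by default. To run one of these groups explicitly, pass a marker expression on the command line, e.g. pytest -m benchmark benchmarking/optimizer_benchmark.py or pytest -m deprecated tests/test_deprecated.py; a -m given on the command line should take precedence over the one injected via addopts, since pytest uses the last -m it sees.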
"""
Extracted from tests/test_optim.py
Usage: pytest benchmarking/optimizer_benchmark.py
"""
import time
import pytest
from tests.helpers import describe_dtype, id_formatter
import torch
import bitsandbytes as bnb
str2optimizers = {"paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW)}
@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
layers1 = layers1.to(gtype)
layers1 = layers1.cuda()
large_tensor = None
if mode == "torch":
optim = str2optimizers[optim_name][0](layers1.parameters())
else:
optim = str2optimizers[optim_name][1](layers1.parameters())
# 12 GB
large_tensor = torch.empty((int(4.5e9),), device="cuda")
torch.cuda.synchronize()
time.sleep(5)
num_batches = 5
batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
for i in range(num_batches):
print(i)
b = batches[i]
if i == 2:
torch.cuda.synchronize()
t0 = time.time()
out1 = layers1(b)
loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
loss1.backward()
optim.step()
torch.cuda.synchronize()
print(mode, time.time() - t0)
@@ -252,12 +252,6 @@ def fill(A, value, device=None, prefetch=True):
    elementwise_func("fill", A, None, value)


-@deprecated("Function will be removed in a future release.", category=FutureWarning)
-def arange(A, device=None):
-    elementwise_func("arange", A, None, 0)
-
-
-@deprecated("Function will be removed in a future release.", category=FutureWarning)
def _mul(A, B, device=None):
    elementwise_func("_mul", A, B, 0)
@@ -408,6 +402,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    return torch.tensor(data, dtype=torch.float32)


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def create_quantile_map(A, total_bits=8):
    q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
    q = q.tolist()
@@ -481,17 +476,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def pre_call(device):
-    prev_device = torch.cuda.current_device()
-    torch.cuda.set_device(device)
-    return prev_device
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def post_call(prev_device):
-    torch.cuda.set_device(prev_device)
def estimate_quantiles(
    A: Tensor,
    out: Optional[torch.Tensor] = None,
@@ -540,15 +524,16 @@ def estimate_quantiles(
    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)

-    device = pre_call(A.device)
+    with _cuda_device_of(A):
        is_on_gpu([A, out])
        if A.dtype == torch.float32:
            lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
        elif A.dtype == torch.float16:
            lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
        else:
            raise NotImplementedError(f"Not supported data type {A.dtype}")
-    post_call(device)

    if num_quantiles < 256:
        step = round(256 / num_quantiles)
@@ -1220,12 +1205,12 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
    torch.Tensor:
        Quantized 8-bit tensor.
    """
-    prev_device = pre_call(A.device)
+    with _cuda_device_of(A):
        if out is None:
            out = torch.zeros_like(A, dtype=torch.uint8)
        is_on_gpu([A, out])
        lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
-    post_call(prev_device)

    return out
@@ -1251,13 +1236,13 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
    torch.Tensor:
        32-bit output tensor.
    """
-    prev_device = pre_call(A.device)
+    with _cuda_device_of(A):
        if out is None:
            out = torch.zeros_like(A, dtype=torch.float32)
        is_on_gpu([code, A, out])
        stream = _get_tensor_stream(A)
        lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
-    post_call(prev_device)

    return out
@@ -1445,7 +1430,7 @@ def optimizer_update_8bit(
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

-    prev_device = pre_call(g.device)
+    with _cuda_device_of(g):
        is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
        if g.dtype == torch.float32 and state1.dtype == torch.uint8:
            str2optimizer8bit[optimizer_name][0](
@@ -1499,7 +1484,6 @@
            raise ValueError(
                f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
            )
-    post_call(prev_device)


def optimizer_update_8bit_blockwise(
@@ -1578,7 +1562,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
        The current optimization steps (number of past gradient norms).
    """
-    prev_device = pre_call(grad.device)
+    with _cuda_device_of(grad):
        is_on_gpu([grad, gnorm_vec])
        if grad.dtype == torch.float32:
            lib.cpercentile_clipping_g32(
@@ -1596,7 +1580,6 @@
            )
        else:
            raise ValueError(f"Gradient type {grad.dtype} not supported!")
-    post_call(prev_device)

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
@@ -2334,7 +2317,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    if out is None:
        out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
    nnz = cooA.nnz
-    prev_device = pre_call(B.device)
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
@@ -2371,6 +2354,7 @@
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

+    with _cuda_device_of(B):
        is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
        if B.dtype == torch.float16:
            lib.cspmm_coo_very_sparse_naive_fp16(
@@ -2407,7 +2391,6 @@
                ccolsB,
            )
        # else: assertion error
-    post_call(prev_device)

    return out
@@ -2464,18 +2447,6 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
        return None


-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def vectorwise_dequant(xq, max1, quant_type="vector"):
-    if quant_type == "vector":
-        x = (xq / C * max1).to(torch.float32)
-        return x
-    else:
-        return None


@deprecated(
    "This function is deprecated and will be removed in a future release.",
    category=FutureWarning,
...
@@ -3,13 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import math
-import os
-
-import torch
-import torch.distributed as dist
-
-import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State
@@ -377,204 +371,3 @@ class PagedAdam32bit(Optimizer2State):
            block_wise,
            is_paged=True,
        )
class AnalysisAdam(torch.optim.Optimizer):
    """Adam that performs 8-bit vs 32-bit error analysis.

    This implementation is modified from torch.optim.Adam based on:
    `Fixed Weight Decay Regularization in Adam`
    (see https://arxiv.org/abs/1711.05101)

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        bnb_analysis="dynamic-blockwise",
        savedir=None,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super().__init__(params, defaults)
        self.analysis = bnb_analysis
        self.savedir = savedir

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
                amsgrad = group.get("amsgrad", False)
                assert not amsgrad

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
                    state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device)
                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                    if amsgrad:
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)

                state["step"] += 1
                beta1, beta2 = group["betas"]
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                e = state["abserrors"]
                rele = state["relerrors"]
                counts = state["counts"]

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group["eps"])
                update_fp32 = exp_avg / denom

                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
                    # embedding layer or too small
                    p_data_fp32 += -step_size * update_fp32
                else:
                    if self.analysis == "dynamic-blockwise":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
                        state1 = F.dequantize_blockwise(C1, S1)
                        C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
                        state2 = F.dequantize_blockwise(C2, S2)
                    elif self.analysis == "dynamic":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "linear":
                        code1 = F.create_linear_map(signed=True).to(p.device)
                        code2 = F.create_linear_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "quantile":
                        code1 = F.estimate_quantiles(exp_avg)
                        code2 = F.estimate_quantiles(exp_avg_sq)
                        C1 = F.quantize_no_absmax(exp_avg, code=code1)
                        state1 = F.dequantize_no_absmax(C1, code1)
                        C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
                        state2 = F.dequantize_no_absmax(C2, code2)
                    elif self.analysis == "my-quantization-routine":
                        pass
                        # 1. get code
                        # 2. quantize
                        # 3. dequantize
                        # Error will be calculated automatically!
                    else:
                        raise ValueError(f"Invalid analysis value: {self.analysis}!")

                    denom = state2.sqrt().add_(group["eps"])
                    update_8bit = state1 / denom

                    abserr = torch.abs(update_8bit - update_fp32)
                    relerr = abserr / torch.abs(update_fp32 + 1e-6)

                    C1, C2 = C1.int(), C2.int()

                    F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
                    F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
                    F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))

                    p_data_fp32 += -step_size * update_fp32

                    if not dist.is_initialized() or dist.get_rank() == 0:
                        if self.savedir != "" and state["step"] % 100 == 0:
                            if not os.path.exists(self.savedir):
                                os.makedirs(self.savedir)
                            shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
                            pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl")
                            pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl")
                            pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl")
                            torch.save(e, pathe)
                            torch.save(rele, pathrele)
                            torch.save(counts, pathcounts)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
@@ -80,7 +80,7 @@ include = ["bitsandbytes*"]
version = {attr = "bitsandbytes.__version__"}

[tool.pytest.ini_options]
-addopts = "-rP"
+addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
# ; --cov=bitsandbytes
# ; # contexts: record which test ran which line; can be seen in html coverage report
# ; --cov-context=test
...
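If the benchmark, deprecated, and slow marks are not already registered elsewhere in this configuration, pytest will emit PytestUnknownMarkWarning for them; a minimal registration sketch for the same [tool.pytest.ini_options] table could look like the following (the description strings are illustrative, not taken from this diff):

markers = [
    "benchmark: longer-running performance benchmarks",
    "deprecated: tests that cover deprecated functionality",
    "slow: slow-running tests",
]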
import gc
+import random

+import numpy as np
import pytest
import torch


+def _set_seed():
+    torch.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    torch.mps.manual_seed(0)
+    np.random.seed(0)
+    random.seed(0)


def pytest_runtest_call(item):
    try:
+        _set_seed()
        item.runtest()
    except AssertionError as ae:
        if str(ae) == "Torch not compiled with CUDA enabled":
...
@@ -6,7 +6,6 @@ from tests.helpers import (
    BOOLEAN_TRIPLES,
    TRUE_FALSE,
    describe_dtype,
-    get_test_dims,
    id_formatter,
)
@@ -136,10 +135,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
    torch.testing.assert_close(gradBias1, gradBias2)


-@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
+@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
+@pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@@ -231,85 +230,3 @@ def test_matmul_4bit(
    if req_grad[2]:
        torch.testing.assert_close(gradBias1, gradBias2)
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
"funcs",
[(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
ids=["matmul_fp8_mixed", "matmul_fp8_global"],
)
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad)
req_grad[2] = False
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
torch.nn.init.xavier_uniform_(B)
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B, fw_code, bw_code)
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0:
assert err < 0.115
# assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
grad_err = (gradB1 - gradB2).abs().mean()
assert grad_err.item() < 0.003
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
@@ -3,7 +3,10 @@ import pytest
from scipy.stats import norm
import torch

+import bitsandbytes as bnb
from bitsandbytes import functional as F
+from tests.helpers import BOOLEAN_TRIPLES, describe_dtype, get_test_dims, id_formatter
+from tests.test_autograd import TRANSPOSE_VALS


@pytest.mark.deprecated
@@ -121,3 +124,87 @@ def test_percentile_clipping(gtype):
    torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
    torch.testing.assert_close(clip1, clip2)
    torch.testing.assert_close(gnorm1, gnorm2)
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
"funcs",
[(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
ids=["matmul_fp8_mixed", "matmul_fp8_global"],
)
@pytest.mark.deprecated
@pytest.mark.skip("Deprecated functionality, to be removed.")
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad)
req_grad[2] = False
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
torch.nn.init.xavier_uniform_(B)
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B, fw_code, bw_code)
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0:
assert err < 0.115
# assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
grad_err = (gradB1 - gradB2).abs().mean()
assert grad_err.item() < 0.003
torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
@@ -369,9 +369,9 @@ class TestIGEMMFunctional:
        # print(mean(errors))
        # print(mean(relerrors))

-    @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
-    @pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
-    @pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
+    @pytest.mark.parametrize("hidden_dim", [32, 256], ids=id_formatter("hidden_dim"))
+    @pytest.mark.parametrize("batch_dim", [16, 256], ids=id_formatter("batch_dim"))
+    @pytest.mark.parametrize("seq_dim", [16, 256], ids=id_formatter("seq_dim"))
    @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
    def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
        hidden_dim = hidden_dim - (hidden_dim % 32)
@@ -415,9 +415,9 @@ class TestIGEMMFunctional:
        torch.testing.assert_close(out.float(), out2)

-    @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
-    @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
-    @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
+    @pytest.mark.parametrize("seq_dim", [32, 256, 512], ids=id_formatter("seq_dim"))
+    @pytest.mark.parametrize("hidden_dim", [64, 1024, 4096], ids=id_formatter("hidden_dim"))
+    @pytest.mark.parametrize("batch_dim", [2, 8, 16], ids=id_formatter("batch_dim"))
    def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
        seq_dim = seq_dim - (seq_dim % 32)
        hidden_dim = hidden_dim - (hidden_dim % 32)
@@ -431,9 +431,9 @@ class TestIGEMMFunctional:
        torch.testing.assert_close(out.float(), out2)

-    @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
-    @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
-    @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
+    @pytest.mark.parametrize("seq_dim", [32, 512], ids=id_formatter("seq_dim"))
+    @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
+    @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
    @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
    def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
        def min_max(x):
@@ -501,10 +501,10 @@ class TestIGEMMFunctional:
        assert mean(errs) < 0.015
        assert mean(relerrs) < 0.3

-    @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
-    @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
-    @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
-    @pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
+    @pytest.mark.parametrize("dim1", [1, 64], ids=id_formatter("dim1"))
+    @pytest.mark.parametrize("dim2", [32, 128], ids=id_formatter("dim2"))
+    @pytest.mark.parametrize("dim3", [32, 256], ids=id_formatter("dim3"))
+    @pytest.mark.parametrize("dim4", [32, 256], ids=id_formatter("dim4"))
    @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
    def test_ibmm(self, dim1, dim2, dim3, dim4, transpose):
        dim2 = dim2 - (dim2 % 16)
@@ -760,8 +760,8 @@ class TestLLMInt8Functional:

class TestSpMMFunctional:
-    @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
-    @pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
+    @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
+    @pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2"))
    @pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
    def test_spmm_coo(self, dim1, dim2, transposed_B):
        threshold = 1.5
@@ -893,7 +893,7 @@ class TestSpMMFunctional:
    @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
    @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
-    @pytest.skip("No longer supported")
+    @pytest.mark.skip("No longer supported")
    def test_integrated_sparse_decomp(self, dim1, dim2):
        threshold = 3.0
        for _ in range(k):
@@ -1096,8 +1096,8 @@ class TestQuantize4BitFunctional:
        assert err.item() < math.log2(blocksize) * 8e-2

    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
-    def test_4bit_compressed_stats(self, quant_type):
-        for blocksize in [128, 64]:
+    @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
+    def test_4bit_compressed_stats(self, quant_type, blocksize):
        errs1 = []
        errs2 = []
        for i in range(10):
@@ -1125,9 +1125,6 @@ class TestQuantize4BitFunctional:
        assert err.item() < 0.11
        assert relerr.item() < 0.28

-        # print(sum(errs1)/len(errs1), blocksize, quant_type)
-        # print(sum(errs2)/len(errs2), blocksize, quant_type)

    # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
    @pytest.mark.parametrize("quant_type", ["nf4"])
    @pytest.mark.benchmark
@@ -1169,10 +1166,8 @@
        [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
        ids=describe_dtype,
    )
-    def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind):
-        for dim in [128, 256, 512, 1024]:
-            # for dim in [4*1024]:
-            # for dim in [1*16]:
+    @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
+    def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, kind):
        errs1 = []
        errs2 = []
        errs3 = []
@@ -1361,18 +1356,3 @@ def test_managed():
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A == 17 * (2**3)).sum().item() == n * n
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
@pytest.mark.deprecated
def test_vector_quant(dim1, dim2, dim3):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
for i in range(k):
A = torch.randn(size=(dim2, dim3), device="cuda")
qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
@@ -60,7 +60,7 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


-models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"]
+models = ["bigscience/bloom-1b7"]
dtypes = ["nf4", "fp4"]
...
@@ -149,6 +149,9 @@ class Test4bitBlockwiseQuantOps:
        if device == "cpu" and quant_type != "nf4":
            pytest.skip("CPU implementation is only available for nf4")

+        if storage_dtype != torch.uint8:
+            pytest.xfail("Known issue with storage_dtype != uint8")

        A = torch.randn(1024, 1024, dtype=dtype, device=device)
        out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)
...
@@ -604,44 +604,3 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    params = (total_steps - total_steps // 5) * dim1 * dim2
    print(optim_name, gtype, s, params, s / params)
    # assert s < 3.9


@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
    if mode == "torch":
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # 12 GB
        large_tensor = torch.empty((int(4.5e9),), device="cuda")

    torch.cuda.synchronize()
    time.sleep(5)

    num_batches = 5
    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)
        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
        loss1.backward()
        optim.step()

    torch.cuda.synchronize()
    print(mode, time.time() - t0)
@@ -11,6 +11,7 @@ from tests.helpers import TRUE_FALSE
    not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
    reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
)
+@pytest.mark.skip("No longer supported.")
@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
def test_switchback(vector_wise_quantization):
    for dim in [83]:
...