Unverified commit e82f72b3 authored by Matthew Douglas, committed by GitHub

PyTorch Custom Operator Integration (#1544)



* Sketch out first custom op registration

* Add note

* Initial int8 op registration

* Cleanup some deprecated functions.

* Int8 ops updates; tests

* Implement 4bit quant/dequant ops

* Fix nested quant

* cleanup

* Test improvements

* Clean up and improve tests

* Add higher level custom op for int8 matmul + dequant + bias

* Add gemv 4bit custom op

* Cleanup

* Implement out kwarg overloads for custom ops

* Update PyTorch minimum to 2.1

* Deprecation updates

* Deprecation updates

* Cleanup; rename int8_linear_dequant -> int8_scaled_mm

* Bump min pytorch to 2.2

* cleanup

* Test reorganization

* Remove deprecated supports_igemmlt

* More cleanup

* Cleanup obsolete C++/CUDA code

* Cleanup

* Create 'default' backend for fallback op implementations; initial CPU nf4 work

* Stub out for multi-platform

* Fix serialization tests for torch>=2.6.0

* Add example for torch.compile e2e inference

* Test update

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent f0735f95
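A minimal sketch of the torch.compile end-to-end inference flow this custom-op registration enables, assuming a CUDA device and the existing bitsandbytes.nn.Linear4bit module (illustrative only, not taken from the diff):

import torch
import bitsandbytes as bnb

# Toy example: a single NF4-quantized linear layer compiled with torch.compile.
# The bitsandbytes::* custom ops registered in this PR are what let the compiler
# trace through the quantized forward pass.
qlinear = bnb.nn.Linear4bit(1024, 1024, bias=False, quant_type="nf4", compute_dtype=torch.float16)
qlinear = qlinear.cuda()  # weights are quantized to NF4 when moved to the GPU

compiled = torch.compile(qlinear)
with torch.inference_mode():
    x = torch.randn(1, 1024, dtype=torch.float16, device="cuda")
    y = compiled(x)  # shape (1, 1024), dtype float16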
@@ -3,18 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import research, utils
import torch
from . import _ops, research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
matmul,
matmul_4bit,
matmul_cublas,
mm_cublas,
)
from .backends.cpu import ops as cpu_ops
from .backends.default import ops as default_ops
from .nn import modules
from .optim import adam
# This is a signal for integrations with transformers/diffusers.
# Eventually, we will remove this and check based on release version.
features = {"multi-backend"}
supported_torch_devices = {
"cuda",
"cpu",
# "mps",
# "xpu",
# "hpu",
# "npu",
}
if torch.cuda.is_available():
from .backends.cuda import ops as cuda_ops
__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
from math import prod
from typing import Optional, Sequence, Tuple
import torch
_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
_IS_TORCH_GTE_24 = True
register_fake = torch.library.register_fake
register_kernel = torch.library.register_kernel
else:
# PyTorch <= 2.3
register_fake = torch.library.impl_abstract
register_kernel = torch.library.impl
# Higher level op: int8 matmul + dequant + bias
torch.library.define(
"bitsandbytes::int8_scaled_mm",
"(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)
@register_fake("bitsandbytes::int8_scaled_mm")
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=dtype)
torch.library.define(
"bitsandbytes::int8_linear_matmul",
"(Tensor A, Tensor B) -> Tensor",
)
@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor):
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
torch._check(B.dtype == torch.int8, lambda: "B must be int8")
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=torch.int32)
# More info on `out` overloads:
# https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
"bitsandbytes::int8_linear_matmul.out",
"(Tensor A, Tensor B, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::int8_linear_matmul.out")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
shapeC = (*A.shape[:-1], B.shape[0])
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
torch._check(B.dtype == torch.int8, lambda: "B must be int8")
torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
torch.library.define(
"bitsandbytes::int8_vectorwise_quant",
"(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)
@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
if threshold == 0.0:
return out_row, row_stats, None
outlier_cols = torch.library.get_ctx().new_dynamic_size()
return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)
torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")
@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
return torch.empty_like(A, dtype=torch.float32)
# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
def _(A: torch.Tensor, stats: torch.Tensor):
# To dequantize we divide by 127, or multiply by the reciprocal.
return A * stats.view(-1, 1) * 7.874015718698502e-3
torch.library.define(
"bitsandbytes::int8_mm_dequant",
"(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
)
@register_fake("bitsandbytes::int8_mm_dequant")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: "A must be int32")
return torch.empty_like(A, dtype=dtype)
torch.library.define(
"bitsandbytes::int8_double_quant",
"(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)
@register_fake("bitsandbytes::int8_double_quant")
def _(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
out_row = torch.empty_like(A, dtype=torch.int8)
out_col = torch.empty_like(A, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
outlier_n = torch.library.get_ctx().new_dynamic_size()
outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
return out_row, out_col, row_stats, col_stats, outlier_cols
torch.library.define(
"bitsandbytes::dequantize_4bit",
"(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)
@register_fake("bitsandbytes::dequantize_4bit")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
return torch.empty(shape, dtype=dtype, device=A.device)
torch.library.define(
"bitsandbytes::dequantize_4bit.out",
"(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::dequantize_4bit.out")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch.library.define(
"bitsandbytes::quantize_4bit",
"(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)
@register_fake("bitsandbytes::quantize_4bit")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
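# Ceiling division: -(n // -blocksize) == ceil(n / blocksize), the number of blocks needed to cover n elements.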
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
return out, absmax
torch.library.define(
"bitsandbytes::dequantize_blockwise",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)
@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
return torch.empty_like(A, dtype=dtype)
torch.library.define(
"bitsandbytes::dequantize_blockwise.out",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::dequantize_blockwise.out")
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
):
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
return out, absmax
torch.library.define(
"bitsandbytes::gemv_4bit",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)
@register_fake("bitsandbytes::gemv_4bit")
def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
shape = (*A.shape[:-1], shapeB[0])
return torch.empty(shape, device=A.device, dtype=A.dtype)
torch.library.define(
"bitsandbytes::gemv_4bit.out",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::gemv_4bit.out")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
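A small sketch of what the fake (meta) registrations above provide, assuming bitsandbytes imports cleanly on the host: under FakeTensorMode only the shape/dtype logic runs, which is exactly what torch.compile needs during tracing.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import bitsandbytes  # noqa: F401 -- importing registers the bitsandbytes::* ops

with FakeTensorMode():
    # No real kernel is executed here; the register_fake implementation above
    # computes only the output shape and dtype.
    A = torch.empty(16, 64, dtype=torch.int8)
    B = torch.empty(32, 64, dtype=torch.int8)
    C = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
    print(C.shape, C.dtype)  # torch.Size([16, 32]) torch.int32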
@@ -49,6 +49,10 @@ class GlobalOutlierPooler:
return torch.Tensor(list(self.outliers)).to(torch.int64)
@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def get_inverse_transform_indices(
transform_tile: Callable[[torch.Tensor], torch.Tensor],
tile_size: Tuple[int, int],
@@ -80,6 +84,10 @@ def get_inverse_transform_indices(
return permuted_tile_indices
@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
@@ -98,152 +106,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
return outputs.reshape(rows, cols).contiguous()
@deprecated(
"MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
category=FutureWarning,
)
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
if precision is None:
precision = [8, 8, 8]
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
if len(B.shape) == 2:
dim = 0
else:
dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
if A.requires_grad or B.requires_grad:
ctx.save_for_backward(A, B)
ctx.quant_type = quant_type
ctx.precision = precision
return output
@staticmethod
def backward(ctx, grad_output):
A, B = ctx.saved_tensors
quant_type = ctx.quant_type
precision = ctx.precision
grad_A = grad_B = None
if B.requires_grad:
if len(A.shape) == 3:
dims = [0, 1]
# bsi -> ibs
permute_dim = [0, 2, 1]
else:
dims = [0]
# bs -> sb
permute_dim = [1, 0]
if precision[1] != 8:
with torch.no_grad():
grad_B = torch.matmul(A.permute(permute_dim), grad_output)
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
if not grad_output.is_contiguous():
grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(
grad_output.view(-1, grad_output.shape[2]),
dim=0,
quant_type=quant_type,
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
S2.permute(permute_dim),
S1,
grad_output.dtype,
quant_type,
)
if A.requires_grad:
if len(grad_output.shape) == 3:
dims = [2]
else:
dims = [1]
if len(B.shape) == 3:
# bio -> boi
permute_dim = [0, 2, 1]
dim_B = dims
else:
# io -> oi
permute_dim = [1, 0]
dim_B = [1]
if precision[2] != 8:
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
igrad_A,
S1,
S3.permute(permute_dim),
grad_output.dtype,
quant_type,
)
return grad_A, grad_B, None, None, None
mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
_tile_indices: Optional[torch.Tensor] = None  # TODO: remove
force_no_igemmlt: bool = False
@@ -279,9 +144,7 @@ class MatmulLtState:
@property
def tile_indices(self):
if self._tile_indices is None:
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
return self._tile_indices
raise ValueError("tile_indices is no longer supported.")
class MatMul8bitLt(torch.autograd.Function):
@@ -360,20 +223,12 @@ class MatMul8bitLt(torch.autograd.Function):
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers = state.CB[:, state.idx]
state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
else:
subA = None
# 3. Int8 Matmul
out32 = F.int8_linear_matmul(CA, state.CB)
# Dequantize matmul result
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
# 3. Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
# 4. Mixed-precision decomposition matmul
if subA is not None and state.subB is not None:
@@ -423,8 +278,14 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradB:
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default(
Cgrad.t().contiguous(),
CAt.t(),
SCgradt,
SCAt,
dtype=torch.float16,
)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
import ctypes as ct
from typing import Optional, Tuple
import torch
from bitsandbytes.functional import get_ptr
from ..._ops import register_kernel
from ...cextension import lib
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
# This is fixed in torch 2.6+, so we set this as the minimum to be safe.
# For more information: https://github.com/pytorch/pytorch/pull/136942
# TODO(matthewdouglas): aarch64?
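# Note: torch.__version__ is a TorchVersion (a str subclass) that supports comparison against a version tuple.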
if torch.__version__ >= (2, 6):
@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
def _(A: torch.Tensor, B: torch.Tensor):
return torch._int_mm(
A.reshape(-1, A.shape[-1]),
B.t(),
).reshape(*A.shape[:-1], B.shape[0])
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
A_calc = A.view(-1, A.shape[-1])
row_stats = row_stats.reshape(-1).unsqueeze(-1)
col_stats = col_stats.reshape(-1).unsqueeze(0)
out = A_calc * (row_stats * col_stats) * 6.200124e-05
if bias is not None:
out += bias
return out.to(dtype)
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
return out, absmax
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
out = torch.empty_like(A, dtype=dtype)
lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
return out
_NF4_QUANT_TABLE = torch.tensor(
[
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
],
dtype=torch.float32,
device="cpu",
)
@register_kernel("bitsandbytes::quantize_4bit", "cpu")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
n = A.numel()
# TODO: Support when weight matrix is not divisible by blocksize
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
# Divide into blocks and normalize
blocks = A.reshape(-1, blocksize)
absmax = blocks.abs().max(dim=1).values.float()
scaled = blocks / absmax.unsqueeze(-1)
# Quantize with the lookup table
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
# Pack two quantized values per byte
packed = quantized[::2] << 4 | quantized[1::2]
if quant_storage != torch.uint8:
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax.float()
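A short usage sketch for the CPU blockwise kernels above (illustrative; it assumes bitsandbytes.functional.create_dynamic_map() returns the 256-entry float32 codebook):

import torch
import bitsandbytes  # noqa: F401 -- registers the CPU kernels above
from bitsandbytes.functional import create_dynamic_map

A = torch.randn(4096, dtype=torch.float32)  # CPU tensor; these kernels require float32
code = create_dynamic_map()                 # 256-entry quantization codebook
packed, absmax = torch.ops.bitsandbytes.quantize_blockwise.default(A, code, 256)
restored = torch.ops.bitsandbytes.dequantize_blockwise.default(packed, absmax, code, 256, torch.float32)
print((A - restored).abs().max())           # small blockwise quantization error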
import ctypes as ct
from math import prod
from typing import Optional, Sequence, Tuple
import torch
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
from ..._ops import register_kernel
from ...cextension import lib
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
def _(A: torch.Tensor, B: torch.Tensor):
out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32)
return _int8_linear_matmul_impl(A, B, out)
@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
_int8_linear_matmul_impl(A, B, out)
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
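# Note: operands are swapped so that, below, A refers to the weight matrix (outputs, inputs)
# and B to the activations; see the lda/ldb comments further down.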
A, B = B, A
shapeA = A.shape
shapeB = B.shape
torch._check(A.dtype == torch.int8, lambda: "B must be int8")
torch._check(B.dtype == torch.int8, lambda: "A must be int8")
torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
torch._check(out.dtype == torch.int32)
shapeC = (*shapeB[:-1], shapeA[0])
torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}")
k, m = shapeA
n = prod(shapeB[:-1])
lda = shapeA[-1] # Weights (outputs, inputs)
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
ldc = shapeC[-1] # Output (batch, tokens, outputs)
torch._check(
lda == ldb,
lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
)
# cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
# We'll fall back to a slower fp32 calculation in this circumstance.
# Fortunately, this should not be very common.
if lda % 4 != 0:
result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
return out.copy_(result)
with _cuda_device_of(A):
ctx = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
ptrRowScale = None
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)
stream = _get_tensor_stream(A)
has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
if has_error:
if has_error == 100:
# `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
# TODO: Warn and implement a fallback to fp32 compute?
raise NotImplementedError("int8_linear_matmul not implemented!")
else:
raise RuntimeError(
f"cublasLt ran into an error!\n"
f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
f"\t{(lda, ldb, ldc)=}\n"
f"\t{(m, n, k)=}"
)
return out
@register_kernel("bitsandbytes::int8_mm_dequant", "cuda")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
# Note: cuda kernel only currently supports fp16 output.
# We'll later cast to desired dtype if needed.
out = torch.empty_like(A, dtype=torch.float16)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
numRows = ct.c_int32(prod(A.shape[:-1]))
numCols = ct.c_int32(A.shape[-1])
# Note: fused bias in the kernel is only supported for fp16
# TODO(matthewdouglas): Consider supporting bf16 fused bias
ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None
with _cuda_device_of(A):
lib.cdequant_mm_int32_fp16(
ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
)
# Add bias separately if not fused in kernel
if bias is not None and bias.dtype != torch.float16:
out.add_(bias)
return out.to(dtype)
@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
def _(A: torch.Tensor, threshold=0.0):
torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")
rows = prod(A.shape[:-1])
cols = A.shape[-1]
row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
outlier_cols = None
if threshold > 0.0:
# TODO we could improve perf of this
outliers = A.abs() >= threshold
if outliers.any():
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
with _cuda_device_of(A):
lib.cint8_vector_quant(
get_ptr(A),
get_ptr(out_row),
get_ptr(row_stats),
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
_get_tensor_stream(A),
)
# Zero out values from outlier columns across all rows.
# The kernel will handle this for outliers themselves, so we can optimize for rows=1.
if rows > 1 and outlier_cols is not None:
out_row[:, outlier_cols] = 0
return out_row, row_stats, outlier_cols
@register_kernel("bitsandbytes::int8_double_quant", "cuda")
def _(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
# Use CUDA kernel for rowwise and COO tensor
quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
A,
threshold=threshold,
)
# PyTorch impl for colwise
col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)
if threshold > 0.0 and outlier_mask is not None:
A = A.masked_fill(outlier_mask, 0.0)
quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8)
return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
def _get_col_absmax(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
torch._check(A.is_floating_point())
outlier_mask = None
absA = A.abs().view(-1, A.shape[-1])
if threshold > 0.0:
# Filter outliers from stats when enabled
outlier_mask = absA >= threshold
absA.masked_fill_(outlier_mask, 0.0)
# shape [cols]; unsqueeze(0) gives [1,cols]
col_stats = absA.amax(dim=0, keepdim=False).float()
return col_stats, outlier_mask
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
with _cuda_device_of(A):
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(A.numel()),
)
if A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(*args)
elif A.dtype == torch.bfloat16:
lib.cquantize_blockwise_bf16(*args)
elif A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(*args)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
return out, absmax
@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
out = torch.empty_like(A, dtype=dtype)
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
return out
@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
)
with _cuda_device_of(A):
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
_get_tensor_stream(A),
)
if dtype == torch.float16:
lib.cdequantize_blockwise_fp16(*args)
elif dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(*args)
elif dtype == torch.float32:
lib.cdequantize_blockwise_fp32(*args)
@register_kernel("bitsandbytes::quantize_4bit", "cuda")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(quant_type in ["fp4", "nf4"])
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
with _cuda_device_of(A):
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
if A.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4(*args)
else:
lib.cquantize_blockwise_bf16_nf4(*args)
elif A.dtype == torch.float16:
if quant_type == "fp4":
lib.cquantize_blockwise_fp16_fp4(*args)
else:
lib.cquantize_blockwise_fp16_nf4(*args)
elif A.dtype == torch.float32:
if quant_type == "fp4":
lib.cquantize_blockwise_fp32_fp4(*args)
else:
lib.cquantize_blockwise_fp32_nf4(*args)
return out, absmax
@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
out = torch.empty(shape, dtype=dtype, device=A.device)
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(quant_type in ["fp4", "nf4"])
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
with _cuda_device_of(A):
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
_get_tensor_stream(A),
)
if out.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif out.dtype == torch.float16:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif out.dtype == torch.float32:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)
@register_kernel("bitsandbytes::gemv_4bit", "cuda")
def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(shape, device=A.device, dtype=A.dtype)
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
return out
@register_kernel("bitsandbytes::gemv_4bit.out", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
def _gemv_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(
A.numel() == A.size(-1),
lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
)
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
m = ct.c_int32(shapeB[0])
n = ct.c_int32(1)
k = ct.c_int32(shapeB[1])
lda = m
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
ldc = m
stream = _get_tensor_stream(A)
with _cuda_device_of(A):
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
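An illustrative round trip through the CUDA 4-bit ops above (a sketch, assuming a CUDA device is available):

import torch
import bitsandbytes  # noqa: F401 -- registers the CUDA kernels above

# Quantize a half-precision weight to NF4 with blocksize 64, then dequantize it back.
W = torch.randn(256, 256, dtype=torch.float16, device="cuda")
packed, absmax = torch.ops.bitsandbytes.quantize_4bit.default(W, 64, "nf4", torch.uint8)
W_hat = torch.ops.bitsandbytes.dequantize_4bit.default(packed, absmax, 64, "nf4", W.shape, W.dtype)
print((W - W_hat).abs().mean())  # reconstruction error introduced by 4-bit quantization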
from typing import Optional
import torch
from ..._ops import register_kernel
@register_kernel("bitsandbytes::int8_scaled_mm", None)
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
return out
@register_kernel("bitsandbytes::int8_linear_matmul", None)
def _(A: torch.Tensor, B: torch.Tensor):
return _int8_linear_matmul_impl(A, B)
@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
torch._check(out.dtype == torch.int32)
_int8_linear_matmul_impl(A, B, out)
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
# Naive implementation: perform matmul in fp32
result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
if out is not None:
result = out.copy_(result)
return result
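A brief sketch of the fallback path these default registrations provide: int8_scaled_mm composes the lower-level ops, so it also works on CPU tensors (assuming torch >= 2.6 so the CPU int8 kernels above are registered):

import torch
import bitsandbytes  # noqa: F401 -- registers the default and CPU kernels

A = torch.randint(-128, 128, (32, 64), dtype=torch.int8)   # quantized activations
B = torch.randint(-128, 128, (48, 64), dtype=torch.int8)   # quantized weights
row_stats = torch.ones(32, dtype=torch.float32)
col_stats = torch.ones(48, dtype=torch.float32)
out = torch.ops.bitsandbytes.int8_scaled_mm.default(A, B, row_stats, col_stats, dtype=torch.float32)
print(out.shape, out.dtype)  # torch.Size([32, 48]) torch.float32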
@@ -19,7 +19,9 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
The library is not guaranteed to exist at the returned path.
"""
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
prefix = "rocm" if torch.version.hip else "cuda"
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
@@ -76,7 +78,7 @@ def get_native_library() -> BNBNativeLibrary:
logger.warning(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
"8-bit optimizers and GPU quantization are unavailable.",
)
return BNBNativeLibrary(dll)
import dataclasses
from functools import lru_cache
from typing import List, Optional, Tuple
import torch
@@ -12,22 +13,26 @@ class CUDASpecs:
@property
def has_imma(self) -> bool:
return self.highest_compute_capability >= (7, 5)
return torch.version.hip or self.highest_compute_capability >= (7, 5)
def get_compute_capabilities() -> List[Tuple[int, int]]:
return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))
@lru_cache(None)
def get_cuda_version_tuple() -> Tuple[int, int]:
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
major, minor = map(int, torch.version.cuda.split("."))
return major, minor
if torch.version.cuda:
return map(int, torch.version.cuda.split(".")[0:2])
elif torch.version.hip:
return map(int, torch.version.hip.split(".")[0:2])
return None
def get_cuda_version_string() -> str:
major, minor = get_cuda_version_tuple()
return f"{major}{minor}"
return f"{major * 10 + minor}"
def get_cuda_specs() -> Optional[CUDASpecs]:
@@ -11,7 +11,6 @@ from torch import Tensor, device, dtype, nn
import torch.nn.functional as F
import bitsandbytes as bnb
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
@@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
raise ValueError(f"Only 'row' weight format is supported, got {weight_format}")
class Embedding8bit(nn.Embedding):
@@ -121,8 +121,6 @@ template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8Vector
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
@@ -374,48 +374,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
}
}
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
cublasOperation_t opTranspose = CUBLAS_OP_T;
float transformAlpha = 1.0f, transformBeta = 0.0f;
if(DTYPE == 8)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
}
else if(DTYPE == 32)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
}
else
{
printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
}
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
}
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
cublasLtHandle_t ltHandle,
int m, int n, int k,
@@ -542,50 +500,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
{
int threads = 256;
int items_per_thread = 8;
// we load 128 column values per warp
int tile_cols = 32*items_per_thread;
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 8);
else
outRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 32);
else
outRows = fill_up_to_nearest_multiple(rows, 32);
}
else
{
if(TRANSPOSE)
{
outCols = fill_up_to_nearest_multiple(rows, 32);
outRows = cols;
}
}
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
cusparseSpMatDescr_t descA;
@@ -643,32 +557,6 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
{
int threads = 256;
// we load 128 column values per warp
int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
int tiledRows = 0;
int num_blocks = idx_size;
if(FORMAT == COL_TURING)
{
tiledRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
tiledRows = fill_up_to_nearest_multiple(rows, 32);
}
kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
@@ -722,8 +610,6 @@ template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float *
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
@@ -732,13 +618,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
@@ -840,15 +719,6 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);
@@ -173,20 +173,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
......
...@@ -149,32 +149,6 @@ void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv ...@@ -149,32 +149,6 @@ void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
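For readers skimming this hunk, the `dequantizeBlockwise_bf16_fp4` / `dequantizeBlockwise_bf16_nf4` wrappers above expose blockwise codebook dequantization: each stored 4-bit index selects one of 16 codebook values, which is then rescaled by its block's `absmax`. The following is a minimal pure-PyTorch sketch of that arithmetic for orientation only, not the CUDA kernel; the nibble packing order and the `code` lookup table passed in are assumptions for illustration.

```python
import torch

def dequantize_blockwise_4bit_reference(
    packed: torch.Tensor,   # uint8, two 4-bit codebook indices per byte
    absmax: torch.Tensor,   # float, one scale per block
    code: torch.Tensor,     # float, 16-entry codebook (e.g. the FP4/NF4 levels)
    blocksize: int,
    n: int,                 # number of dequantized elements
) -> torch.Tensor:
    # Unpack nibbles; assuming the first element lives in the high nibble.
    hi = (packed >> 4).long()
    lo = (packed & 0x0F).long()
    idx = torch.stack([hi, lo], dim=-1).reshape(-1)[:n]
    # Codebook lookup, then rescale each element by its block's absmax.
    block_id = torch.arange(n, device=packed.device) // blocksize
    return code[idx] * absmax[block_id]
```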
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
} \
MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
} }
...@@ -317,22 +291,6 @@ extern "C" ...@@ -317,22 +291,6 @@ extern "C"
int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
} }
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
} \
MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream) void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); } { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
...@@ -342,24 +300,6 @@ extern "C" ...@@ -342,24 +300,6 @@ extern "C"
int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
} }
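Taken together, the entry points in this hunk (`cint8_vector_quant`, `cget_row_stats`, `cdequant_mm_int32_fp16`, plus the `cigemmlt_*` wrappers above) cover a row-wise absmax quantize → int32 GEMM → dequantize-plus-bias pipeline. Below is a rough pure-PyTorch CPU sketch of that arithmetic under the assumption of symmetric absmax scaling; it omits the outlier-threshold path, so it is an illustration of the math rather than a drop-in replacement for the CUDA kernels.

```python
from typing import Optional

import torch

def int8_scaled_mm_reference(
    A: torch.Tensor,                      # [m, k] activations
    B: torch.Tensor,                      # [n, k] weights, row-major
    bias: Optional[torch.Tensor] = None,  # [n]
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # Per-row absmax of A and per-output-row absmax of B (clamped to avoid /0).
    row_stats = A.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-8)
    col_stats = B.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-8)
    # Symmetric quantization to int8.
    A_i8 = torch.round(A.float() * 127.0 / row_stats).to(torch.int8)
    B_i8 = torch.round(B.float() * 127.0 / col_stats).to(torch.int8)
    # int32 accumulation (CPU integer matmul here; the CUDA path uses cuBLASLt igemmlt).
    C_i32 = A_i8.to(torch.int32) @ B_i8.to(torch.int32).t()
    # Dequantize with the two scale vectors and add the optional bias.
    out = C_i32.float() * row_stats * col_stats.t() / (127.0 * 127.0)
    if bias is not None:
        out = out + bias.float()
    return out.to(dtype)
```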
void ctransform_row2col32(char * A, char *out, int rows, int cols)
{ transform_row2col32(A, out, rows, cols); }
void ctransform_row2col32T(char * A, char *out, int rows, int cols)
{ transform_row2col32T(A, out, rows, cols); }
void ctransform_row2turing(char * A, char *out, int rows, int cols)
{ transform_row2turing(A, out, rows, cols); }
void ctransform_row2turingT(char * A, char *out, int rows, int cols)
{ transform_row2turingT(A, out, rows, cols); }
void ctransform_row2ampere(char * A, char *out, int rows, int cols)
{ transform_row2ampere(A, out, rows, cols); }
void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
{ transform_row2ampereT(A, out, rows, cols); }
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
...@@ -369,9 +309,6 @@ extern "C" ...@@ -369,9 +309,6 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
......
...@@ -174,7 +174,7 @@ export BNB_CUDA_VERSION=126 ...@@ -174,7 +174,7 @@ export BNB_CUDA_VERSION=126
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-12.6 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-12.6
``` ```
3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded. 3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 12.6) and a different bitsandbytes library is loaded.
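If you prefer to apply the override from Python rather than the shell, the same idea looks roughly like the snippet below. This assumes the override is read when bitsandbytes is first imported, so it must run before the import; `LD_LIBRARY_PATH` still has to be exported in the shell as shown above, since the dynamic loader does not re-read it inside an already-running process.

```python
import os

# Must be set before importing bitsandbytes; "126" selects the CUDA 12.6 build,
# matching the shell example above.
os.environ["BNB_CUDA_VERSION"] = "126"

import torch
import bitsandbytes as bnb

print("PyTorch was built against CUDA", torch.version.cuda)
```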
## Multi-backend Support (Alpha Release)[[multi-backend]] ## Multi-backend Support (Alpha Release)[[multi-backend]]
......