Unverified commit e82f72b3 authored by Matthew Douglas, committed by GitHub

PyTorch Custom Operator Integration (#1544)



* Sketch out first custom op registration

* Add note

* Initial int8 op registration

* Cleanup some deprecated functions.

* Int8 ops updates; tests

* Implement 4bit quant/dequant ops

* Fix nested quant

* cleanup

* Test improvements

* Clean up and improve tests

* Add higher level custom op for int8 matmul + dequant + bias

* Add gemv 4bit custom op

* Cleanup

* Implement out kwarg overloads for custom ops

* Update PyTorch minimum to 2.1

* Deprecation updates

* Deprecation updates

* Cleanup; rename int8_linear_dequant -> int8_scaled_mm

* Bump min pytorch to 2.2

* cleanup

* Test reorganization

* Remove deprecated supports_igemmlt

* More cleanup

* Cleanup obsolete C++/CUDA code

* Cleanup

* Create 'default' backend for fallback op implementations; initial CPU nf4 work

* Stub out for multi-platform

* Fix serialization tests for torch>=2.6.0

* Add example for torch.compile e2e inference

* Test update

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent f0735f95
@@ -3,18 +3,35 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from . import research, utils
+import torch
+
+from . import _ops, research, utils
 from .autograd._functions import (
     MatmulLtState,
-    bmm_cublas,
     matmul,
     matmul_4bit,
-    matmul_cublas,
-    mm_cublas,
 )
+from .backends.cpu import ops as cpu_ops
+from .backends.default import ops as default_ops
 from .nn import modules
 from .optim import adam
 
+# This is a signal for integrations with transformers/diffusers.
+# Eventually, we will remove this and check based on release version.
+features = {"multi-backend"}
+supported_torch_devices = {
+    "cuda",
+    "cpu",
+    # "mps",
+    # "xpu",
+    # "hpu",
+    # "npu",
+}
+
+if torch.cuda.is_available():
+    from .backends.cuda import ops as cuda_ops
+
 __pdoc__ = {
     "libbitsandbytes": False,
     "optim.optimizer.Optimizer8bit": False,
...
from math import prod
from typing import Optional, Sequence, Tuple
import torch
_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
_IS_TORCH_GTE_24 = True
register_fake = torch.library.register_fake
register_kernel = torch.library.register_kernel
else:
# PyTorch <= 2.3
register_fake = torch.library.impl_abstract
register_kernel = torch.library.impl
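# Note: torch.library.register_fake / register_kernel were introduced in PyTorch 2.4 as the
# successors of impl_abstract / impl; the aliases above keep a single registration code path
# across the supported versions.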
# Higher level op: int8 matmul + dequant + bias
torch.library.define(
"bitsandbytes::int8_scaled_mm",
"(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)
@register_fake("bitsandbytes::int8_scaled_mm")
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=dtype)
torch.library.define(
"bitsandbytes::int8_linear_matmul",
"(Tensor A, Tensor B) -> Tensor",
)
@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor):
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
torch._check(B.dtype == torch.int8, lambda: "B must be int8")
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=torch.int32)
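As an aside, a minimal sketch of what the fake registration above enables (illustrative only, meant to be run in a separate session with the package installed): shape and dtype inference works on meta tensors without dispatching to a real kernel, which is what torch.compile relies on.

import torch
import bitsandbytes  # noqa: F401 -- importing registers the bitsandbytes:: ops

A_meta = torch.empty(16, 64, dtype=torch.int8, device="meta")
B_meta = torch.empty(32, 64, dtype=torch.int8, device="meta")
C_meta = torch.ops.bitsandbytes.int8_linear_matmul.default(A_meta, B_meta)
assert C_meta.shape == (16, 32) and C_meta.dtype == torch.int32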
# More info on `out` overloads:
# https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
"bitsandbytes::int8_linear_matmul.out",
"(Tensor A, Tensor B, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::int8_linear_matmul.out")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
shapeC = (*A.shape[:-1], B.shape[0])
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
torch._check(B.dtype == torch.int8, lambda: "B must be int8")
torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
torch.library.define(
"bitsandbytes::int8_vectorwise_quant",
"(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)
@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
if threshold == 0.0:
return out_row, row_stats, None
outlier_cols = torch.library.get_ctx().new_dynamic_size()
return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)
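# The number of outlier columns is data-dependent, so the fake impl above reports it to the
# compiler as a dynamic size rather than a concrete shape.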
torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")
@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
return torch.empty_like(A, dtype=torch.float32)
# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
def _(A: torch.Tensor, stats: torch.Tensor):
# To dequantize we divide by 127, or multiply by the reciprocal.
return A * stats.view(-1, 1) * 7.874015718698502e-3
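A small PyTorch-only check of the identity used here (illustrative; 7.874015718698502e-3 is approximately 1/127):

import torch

A = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
stats = torch.rand(4, dtype=torch.float32)
assert torch.allclose(A * stats.view(-1, 1) * 7.874015718698502e-3, A.float() * stats.view(-1, 1) / 127.0)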
torch.library.define(
"bitsandbytes::int8_mm_dequant",
"(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
)
@register_fake("bitsandbytes::int8_mm_dequant")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: "A must be int32")
return torch.empty_like(A, dtype=dtype)
torch.library.define(
"bitsandbytes::int8_double_quant",
"(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)
@register_fake("bitsandbytes::int8_double_quant")
def _(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
out_row = torch.empty_like(A, dtype=torch.int8)
out_col = torch.empty_like(A, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
outlier_n = torch.library.get_ctx().new_dynamic_size()
outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
return out_row, out_col, row_stats, col_stats, outlier_cols
torch.library.define(
"bitsandbytes::dequantize_4bit",
"(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)
@register_fake("bitsandbytes::dequantize_4bit")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
return torch.empty(shape, dtype=dtype, device=A.device)
torch.library.define(
"bitsandbytes::dequantize_4bit.out",
"(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::dequantize_4bit.out")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch.library.define(
"bitsandbytes::quantize_4bit",
"(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)
@register_fake("bitsandbytes::quantize_4bit")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
return out, absmax
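A worked example of the shapes produced above (illustrative): a 1024x1024 input with blocksize 64 and uint8 storage.

import torch

n = 1024 * 1024
blocksize = 64
assert -(n // -blocksize) == 16384                        # ceil(n / blocksize) absmax entries
assert (n + 1) // (torch.uint8.itemsize * 2) == 524288    # two 4-bit codes packed per byte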
torch.library.define(
"bitsandbytes::dequantize_blockwise",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)
@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
return torch.empty_like(A, dtype=dtype)
torch.library.define(
"bitsandbytes::dequantize_blockwise.out",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::dequantize_blockwise.out")
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
):
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
return out, absmax
torch.library.define(
"bitsandbytes::gemv_4bit",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)
@register_fake("bitsandbytes::gemv_4bit")
def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
shape = (*A.shape[:-1], shapeB[0])
return torch.empty(shape, device=A.device, dtype=A.dtype)
torch.library.define(
"bitsandbytes::gemv_4bit.out",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
)
@register_fake("bitsandbytes::gemv_4bit.out")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
@@ -49,6 +49,10 @@ class GlobalOutlierPooler:
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
 
+@deprecated(
+    "This function is deprecated and will be removed in a future release.",
+    category=FutureWarning,
+)
 def get_inverse_transform_indices(
     transform_tile: Callable[[torch.Tensor], torch.Tensor],
     tile_size: Tuple[int, int],
@@ -80,6 +84,10 @@ def get_inverse_transform_indices(
     return permuted_tile_indices
 
 
+@deprecated(
+    "This function is deprecated and will be removed in a future release.",
+    category=FutureWarning,
+)
 def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
     """
     Undo a tiled permutation such as turing or ampere layout
@@ -98,152 +106,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
     return outputs.reshape(rows, cols).contiguous()
@deprecated(
"MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
category=FutureWarning,
)
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
if precision is None:
precision = [8, 8, 8]
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
if len(B.shape) == 2:
dim = 0
else:
dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
if A.requires_grad or B.requires_grad:
ctx.save_for_backward(A, B)
ctx.quant_type = quant_type
ctx.precision = precision
return output
@staticmethod
def backward(ctx, grad_output):
A, B = ctx.saved_tensors
quant_type = ctx.quant_type
precision = ctx.precision
grad_A = grad_B = None
if B.requires_grad:
if len(A.shape) == 3:
dims = [0, 1]
# bsi -> ibs
permute_dim = [0, 2, 1]
else:
dims = [0]
# bs -> sb
permute_dim = [1, 0]
if precision[1] != 8:
with torch.no_grad():
grad_B = torch.matmul(A.permute(permute_dim), grad_output)
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
if not grad_output.is_contiguous():
grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(
grad_output.view(-1, grad_output.shape[2]),
dim=0,
quant_type=quant_type,
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
S2.permute(permute_dim),
S1,
grad_output.dtype,
quant_type,
)
if A.requires_grad:
if len(grad_output.shape) == 3:
dims = [2]
else:
dims = [1]
if len(B.shape) == 3:
# bio -> boi
permute_dim = [0, 2, 1]
dim_B = dims
else:
# io -> oi
permute_dim = [1, 0]
dim_B = [1]
if precision[2] != 8:
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
igrad_A,
S1,
S3.permute(permute_dim),
grad_output.dtype,
quant_type,
)
return grad_A, grad_B, None, None, None
mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
 @dataclass
 class MatmulLtState:
-    _tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None  # TODO: remove
     force_no_igemmlt: bool = False
@@ -279,9 +144,7 @@ class MatmulLtState:
     @property
     def tile_indices(self):
-        if self._tile_indices is None:
-            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
-        return self._tile_indices
+        raise ValueError("tile_indices is no longer supported.")
 
 
 class MatMul8bitLt(torch.autograd.Function):
@@ -360,20 +223,12 @@ class MatMul8bitLt(torch.autograd.Function):
                 # we want to divide by 127. It's however more performant to multiply
                 # by the reciprocal.
                 outliers = state.CB[:, state.idx]
-                state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
+                state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
             else:
                 subA = None
 
-        # 3. Int8 Matmul
-        out32 = F.int8_linear_matmul(CA, state.CB)
-
-        # Dequantize matmul result
-        if bias is None or bias.dtype == torch.float16:
-            # we apply the fused bias here
-            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
-        else:  # apply bias separately
-            # TODO: Fused bias for fp32/bf16?
-            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
+        # 3. Int8 Matmul + Dequant + Bias
+        output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
 
         # 4. Mixed-precision decomposition matmul
         if subA is not None and state.subB is not None:
@@ -423,8 +278,14 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradB:
             Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
 
-            gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
-            grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
+            grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default(
+                Cgrad.t().contiguous(),
+                CAt.t(),
+                SCgradt,
+                SCAt,
+                dtype=torch.float16,
+            )
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
...
import ctypes as ct
from typing import Optional, Tuple
import torch
from bitsandbytes.functional import get_ptr
from ..._ops import register_kernel
from ...cextension import lib
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
# This is fixed in torch 2.6+, so we set this as the minimum to be safe.
# For more information: https://github.com/pytorch/pytorch/pull/136942
# TODO(matthewdouglas): aarch64?
if torch.__version__ >= (2, 6):
@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
def _(A: torch.Tensor, B: torch.Tensor):
return torch._int_mm(
A.reshape(-1, A.shape[-1]),
B.t(),
).reshape(*A.shape[:-1], B.shape[0])
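An illustrative correctness check of the reshape/transpose convention above (CPU path; assumes bitsandbytes is imported in a recent PyTorch so a kernel for this op is registered):

import torch
import bitsandbytes  # noqa: F401

A = torch.randint(-128, 128, (2, 3, 64), dtype=torch.int8)
B = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
ref = (A.reshape(-1, 64).to(torch.int32) @ B.to(torch.int32).t()).reshape(2, 3, 32)
assert out.shape == (2, 3, 32) and torch.equal(out, ref)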
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
A_calc = A.view(-1, A.shape[-1])
row_stats = row_stats.reshape(-1).unsqueeze(-1)
col_stats = col_stats.reshape(-1).unsqueeze(0)
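# The constant below is ~1/(127*127): it undoes both the row-wise and column-wise int8 scales.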
out = A_calc * (row_stats * col_stats) * 6.200124e-05
if bias is not None:
out += bias
return out.to(dtype)
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
return out, absmax
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
out = torch.empty_like(A, dtype=dtype)
lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
return out
_NF4_QUANT_TABLE = torch.tensor(
[
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
],
dtype=torch.float32,
device="cpu",
)
@register_kernel("bitsandbytes::quantize_4bit", "cpu")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
n = A.numel()
# TODO: Support when weight matrix is not divisible by blocksize
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
# Divide into blocks and normalize
blocks = A.reshape(-1, blocksize)
absmax = blocks.abs().max(dim=1).values.float()
scaled = blocks / absmax.unsqueeze(-1)
# Quantize with the lookup table
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
# Pack two quantized values per byte
packed = quantized[::2] << 4 | quantized[1::2]
if quant_storage != torch.uint8:
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax.float()
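For clarity, an illustrative unpacking of the layout produced above (two 4-bit codes per byte, high nibble first), using the _NF4_QUANT_TABLE defined earlier in this file; this is not the library's dequantization path:

import torch

packed_byte = torch.tensor([0x7F], dtype=torch.uint8)                     # high nibble 0x7, low nibble 0xF
codes = torch.stack([packed_byte >> 4, packed_byte & 0x0F], dim=-1).flatten().long()
values = _NF4_QUANT_TABLE[codes]                                          # codes 7 and 15 -> 0.0 and 1.0
assert torch.allclose(values, torch.tensor([0.0, 1.0]))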
import ctypes as ct
from math import prod
from typing import Optional, Sequence, Tuple
import torch
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
from ..._ops import register_kernel
from ...cextension import lib
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
def _(A: torch.Tensor, B: torch.Tensor):
out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32)
return _int8_linear_matmul_impl(A, B, out)
@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
_int8_linear_matmul_impl(A, B, out)
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
A, B = B, A
shapeA = A.shape
shapeB = B.shape
torch._check(A.dtype == torch.int8, lambda: "B must be int8")
torch._check(B.dtype == torch.int8, lambda: "A must be int8")
torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
torch._check(out.dtype == torch.int32)
shapeC = (*shapeB[:-1], shapeA[0])
torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}")
k, m = shapeA
n = prod(shapeB[:-1])
lda = shapeA[-1] # Weights (outputs, inputs)
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
ldc = shapeC[-1] # Output (batch, tokens, outputs)
torch._check(
lda == ldb,
lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
)
# cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
# We'll fall back to a slower fp32 calculation in this circumstance.
# Fortunately, this should not be very common.
if lda % 4 != 0:
result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
return out.copy_(result)
with _cuda_device_of(A):
ctx = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
ptrRowScale = None
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)
stream = _get_tensor_stream(A)
has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
if has_error:
if has_error == 100:
# `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
# TODO: Warn and implement a fallback to fp32 compute?
raise NotImplementedError("int8_linear_matmul not implemented!")
else:
raise RuntimeError(
f"cublasLt ran into an error!\n"
f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
f"\t{(lda, ldb, ldc)=}\n"
f"\t{(m, n, k)=}"
)
return out
@register_kernel("bitsandbytes::int8_mm_dequant", "cuda")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
# Note: cuda kernel only currently supports fp16 output.
# We'll later cast to desired dtype if needed.
out = torch.empty_like(A, dtype=torch.float16)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
numRows = ct.c_int32(prod(A.shape[:-1]))
numCols = ct.c_int32(A.shape[-1])
# Note: fused bias in the kernel is only supported for fp16
# TODO(matthewdouglas): Consider supporting bf16 fused bias
ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None
with _cuda_device_of(A):
lib.cdequant_mm_int32_fp16(
ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
)
# Add bias separately if not fused in kernel
if bias is not None and bias.dtype != torch.float16:
out.add_(bias)
return out.to(dtype)
@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
def _(A: torch.Tensor, threshold=0.0):
torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")
rows = prod(A.shape[:-1])
cols = A.shape[-1]
row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
outlier_cols = None
if threshold > 0.0:
# TODO we could improve perf of this
outliers = A.abs() >= threshold
if outliers.any():
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
with _cuda_device_of(A):
lib.cint8_vector_quant(
get_ptr(A),
get_ptr(out_row),
get_ptr(row_stats),
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
_get_tensor_stream(A),
)
# Zero out values from outlier columns across all rows.
# The kernel will handle this for outliers themselves, so we can optimize for rows=1.
if rows > 1 and outlier_cols is not None:
out_row[:, outlier_cols] = 0
return out_row, row_stats, outlier_cols
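A PyTorch-only illustration of what `outlier_cols` contains when `threshold > 0` (columns in which any |value| reaches the threshold); the int8 output for those columns is zeroed above:

import torch

A = torch.tensor([[0.1, 8.0], [0.2, 0.3]], dtype=torch.float16)
outlier_cols_ref = torch.argwhere((A.abs() >= 6.0).any(dim=0)).view(-1)
assert outlier_cols_ref.tolist() == [1]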
@register_kernel("bitsandbytes::int8_double_quant", "cuda")
def _(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
# Use CUDA kernel for rowwise and COO tensor
quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
A,
threshold=threshold,
)
# PyTorch impl for colwise
col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)
if threshold > 0.0 and outlier_mask is not None:
A = A.masked_fill(outlier_mask, 0.0)
quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8)
return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
def _get_col_absmax(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
torch._check(A.is_floating_point())
outlier_mask = None
absA = A.abs().view(-1, A.shape[-1])
if threshold > 0.0:
# Filter outliers from stats when enabled
outlier_mask = absA >= threshold
absA.masked_fill_(outlier_mask, 0.0)
# shape [cols]; unsqueeze(0) gives [1,cols]
col_stats = absA.amax(dim=0, keepdim=False).float()
return col_stats, outlier_mask
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
with _cuda_device_of(A):
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(A.numel()),
)
if A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(*args)
elif A.dtype == torch.bfloat16:
lib.cquantize_blockwise_bf16(*args)
elif A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(*args)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
return out, absmax
@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
out = torch.empty_like(A, dtype=dtype)
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
return out
@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
)
with _cuda_device_of(A):
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
_get_tensor_stream(A),
)
if dtype == torch.float16:
lib.cdequantize_blockwise_fp16(*args)
elif dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(*args)
elif dtype == torch.float32:
lib.cdequantize_blockwise_fp32(*args)
@register_kernel("bitsandbytes::quantize_4bit", "cuda")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(quant_type in ["fp4", "nf4"])
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
with _cuda_device_of(A):
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
if A.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4(*args)
else:
lib.cquantize_blockwise_bf16_nf4(*args)
elif A.dtype == torch.float16:
if quant_type == "fp4":
lib.cquantize_blockwise_fp16_fp4(*args)
else:
lib.cquantize_blockwise_fp16_nf4(*args)
elif A.dtype == torch.float32:
if quant_type == "fp4":
lib.cquantize_blockwise_fp32_fp4(*args)
else:
lib.cquantize_blockwise_fp32_nf4(*args)
return out, absmax
@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
out = torch.empty(shape, dtype=dtype, device=A.device)
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
torch._check(quant_type in ["fp4", "nf4"])
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
with _cuda_device_of(A):
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
_get_tensor_stream(A),
)
if out.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif out.dtype == torch.float16:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif out.dtype == torch.float32:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)
@register_kernel("bitsandbytes::gemv_4bit", "cuda")
def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(shape, device=A.device, dtype=A.dtype)
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
return out
@register_kernel("bitsandbytes::gemv_4bit.out", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
def _gemv_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(
A.numel() == A.size(-1),
lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
)
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
m = ct.c_int32(shapeB[0])
n = ct.c_int32(1)
k = ct.c_int32(shapeB[1])
lda = m
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
ldc = m
stream = _get_tensor_stream(A)
with _cuda_device_of(A):
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
from typing import Optional
import torch
from ..._ops import register_kernel
@register_kernel("bitsandbytes::int8_scaled_mm", None)
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
return out
@register_kernel("bitsandbytes::int8_linear_matmul", None)
def _(A: torch.Tensor, B: torch.Tensor):
return _int8_linear_matmul_impl(A, B)
@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
torch._check(out.dtype == torch.int32)
_int8_linear_matmul_impl(A, B, out)
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
# Naive implementation: perform matmul in fp32
result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
if out is not None:
result = out.copy_(result)
return result
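An illustrative equivalence check for the fallback above (small sizes, CPU, assuming bitsandbytes is imported on a recent PyTorch so the kernels shown earlier are registered): the default `int8_scaled_mm` is exactly the two-step matmul-then-dequant path.

import torch
import bitsandbytes  # noqa: F401

A = torch.randint(-16, 16, (4, 32), dtype=torch.int8)
B = torch.randint(-16, 16, (8, 32), dtype=torch.int8)
row_stats = torch.rand(4, dtype=torch.float32)
col_stats = torch.rand(8, dtype=torch.float32)

fused = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, dtype=torch.float32)
i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
ref = torch.ops.bitsandbytes.int8_mm_dequant(i32, row_stats, col_stats, dtype=torch.float32)
assert torch.allclose(fused, ref)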
@@ -19,7 +19,9 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
     The library is not guaranteed to exist at the returned path.
     """
-    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
+    prefix = "rocm" if torch.version.hip else "cuda"
+    library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
 
     override_value = os.environ.get("BNB_CUDA_VERSION")
     if override_value:
@@ -76,7 +78,7 @@ def get_native_library() -> BNBNativeLibrary:
         logger.warning(
             "The installed version of bitsandbytes was compiled without GPU support. "
-            "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
+            "8-bit optimizers and GPU quantization are unavailable.",
         )
     return BNBNativeLibrary(dll)
...
 import dataclasses
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
@@ -12,22 +13,26 @@ class CUDASpecs:
     @property
     def has_imma(self) -> bool:
-        return self.highest_compute_capability >= (7, 5)
+        return torch.version.hip or self.highest_compute_capability >= (7, 5)
 
 
 def get_compute_capabilities() -> List[Tuple[int, int]]:
     return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))
 
 
+@lru_cache(None)
 def get_cuda_version_tuple() -> Tuple[int, int]:
-    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
-    major, minor = map(int, torch.version.cuda.split("."))
-    return major, minor
+    if torch.version.cuda:
+        return map(int, torch.version.cuda.split(".")[0:2])
+    elif torch.version.hip:
+        return map(int, torch.version.hip.split(".")[0:2])
+    return None
 
 
 def get_cuda_version_string() -> str:
     major, minor = get_cuda_version_tuple()
-    return f"{major}{minor}"
+    return f"{major * 10 + minor}"
 
 
 def get_cuda_specs() -> Optional[CUDASpecs]:
...
@@ -182,13 +182,6 @@ class Cusparse_Context:
         return cls._instance
 
 
-dtype2bytes = {}
-dtype2bytes[torch.float32] = 4
-dtype2bytes[torch.float16] = 2
-dtype2bytes[torch.bfloat16] = 2
-dtype2bytes[torch.uint8] = 1
-dtype2bytes[torch.int8] = 1
-
 FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
 
 # When multiple GPUs are present, we use a context manager to
@@ -207,7 +200,7 @@ else:
 
     def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
-        num_bytes = dtype2bytes[dtype] * prod(shape)
+        num_bytes = dtype.itemsize * prod(shape)
         cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
         c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
         new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
@@ -217,15 +210,14 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
     return out
 
 
-def prefetch_tensor(A, to_cpu=False):
+def prefetch_tensor(A: torch.Tensor, to_cpu=False):
     assert A.is_paged, "Only paged tensors can be prefetched!"
     if to_cpu:
         deviceid = -1
    else:
        deviceid = A.page_deviceid
 
-    num_bytes = dtype2bytes[A.dtype] * A.numel()
-    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
+    lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid))
 
 
 def elementwise_func(func_name, A, B, value, prefetch=True):
@@ -431,11 +423,6 @@ def create_quantile_map(A, total_bits=8):
     return q
 
 
-@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning)
-def get_special_format_str():
-    return "row"
-
-
 def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     """Verifies that the input tensors are all on the same device.
@@ -472,11 +459,6 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     return on_gpu
 
 
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
-    return torch.cuda.current_stream(tensor.device)
-
-
 def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
     # We use the raw stream for performance reasons.
     return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
@@ -509,106 +491,6 @@ def post_call(prev_device):
     torch.cuda.set_device(prev_device)
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
print(name)
raise ValueError(
f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
)
else:
return getattr(lib, name)
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
# init_func = torch.empty
init_func = torch.zeros
dims = len(shape)
if dims == 2:
rows = shape[0]
elif dims == 3:
rows = shape[0] * shape[1]
cols = shape[-1]
state = (shape, to_order)
if transpose:
# swap dims
tmp = rows
rows = cols
cols = tmp
state = (shape[::-1], to_order)
if to_order == "row" or to_order == "col":
return init_func(shape, dtype=dtype, device=device), state
elif to_order == "col32":
# blocks of 32 columns (padded)
cols = 32 * ((cols + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
elif to_order == "col_turing":
# blocks of 32 columns and 8 rows
cols = 32 * ((cols + 31) // 32)
rows = 8 * ((rows + 7) // 8)
return init_func((rows, cols), dtype=dtype, device=device), state
elif to_order == "col_ampere":
# blocks of 32 columns and 32 rows
cols = 32 * ((cols + 31) // 32)
rows = 32 * ((rows + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
else:
raise NotImplementedError(f"To_order not supported: {to_order}")
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def nvidia_transform(
A,
to_order,
from_order="row",
out=None,
transpose=False,
state=None,
ld=None,
):
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
else:
new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)
shape = state[0]
if len(shape) == 2:
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
elif ld is not None:
n = prod(shape)
dim1 = prod([shape[i] for i in ld])
dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1)
else:
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptr = CUBLAS_Context.get_instance().get_context(A.device)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
return out, new_state
 def estimate_quantiles(
     A: Tensor,
     out: Optional[torch.Tensor] = None,
@@ -892,56 +774,16 @@ def quantize_blockwise(
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
 
-    if absmax is None:
-        n = A.numel()
-        blocks = -(n // -blocksize)
-        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.uint8)
-
-    if A.device.type != "cpu":
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
-        code = code.to(A.device)
-        is_on_gpu([A, out, absmax])
-        with _cuda_device_of(A):
-            args = (
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(A.numel()),
-            )
-
-            if A.dtype == torch.float16:
-                lib.cquantize_blockwise_fp16(*args)
-            elif A.dtype == torch.bfloat16:
-                lib.cquantize_blockwise_bf16(*args)
-            elif A.dtype == torch.float32:
-                lib.cquantize_blockwise_fp32(*args)
-            else:
-                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
-    else:
-        # cpu
-        code = code.cpu()
-        lib.cquantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_longlong(blocksize),
-            ct.c_longlong(A.numel()),
-        )
+    _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default(
+        A,
+        code.to(A.device),
+        blocksize,
+    )
 
     if nested:
-        offset = absmax.mean()
-        absmax -= offset
-        qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
+        offset = _absmax.mean()
+        _absmax -= offset
+        qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
             code=code,
@@ -951,7 +793,14 @@ def quantize_blockwise(
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
+
+    # TODO(matthewdouglas): Deprecate out kwarg
+    out = out.copy_(_out) if out is not None else _out
+
+    # TODO(matthewdouglas): Deprecate absmax kwarg
+    if absmax is not None:
+        quant_state.absmax = absmax.copy_(quant_state.absmax)
 
     return out, quant_state
@@ -1013,50 +862,25 @@ def dequantize_blockwise(
         if absmax.dtype != torch.float32:
             absmax = absmax.float()
 
-    if out is None:
-        out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
-
-    if A.device.type != "cpu":
-        code = quant_state.code.to(A.device)
-        if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]:
-            raise ValueError(
-                f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]",
-            )
-        is_on_gpu([A, absmax, out])
-        with _cuda_device_of(A):
-            args = (
-                get_ptr(quant_state.code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(A.numel()),
-                _get_tensor_stream(A),
-            )
-
-            if out.dtype == torch.float16:
-                lib.cdequantize_blockwise_fp16(*args)
-            elif out.dtype == torch.bfloat16:
-                lib.cdequantize_blockwise_bf16(*args)
-            elif out.dtype == torch.float32:
-                lib.cdequantize_blockwise_fp32(*args)
-            else:
-                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
-    else:
-        code = quant_state.code.cpu()
-        lib.cdequantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_longlong(quant_state.blocksize),
-            ct.c_longlong(A.numel()),
-        )
-
-    return out
+    if out is not None:
+        torch.ops.bitsandbytes.dequantize_blockwise.out(
+            A,
+            absmax,
+            code.to(A.device),
+            blocksize,
+            quant_state.dtype,
+            out=out,
+        )
+        return out
+
+    return torch.ops.bitsandbytes.dequantize_blockwise.default(
+        A,
+        absmax,
+        quant_state.code.to(A.device),
+        quant_state.blocksize,
+        quant_state.dtype,
+    )
 
 
 def get_4bit_type(typename, device=None, blocksize=64):
     if device is None:
@@ -1194,62 +1018,21 @@ def quantize_4bit(
         - `torch.Tensor`: The quantized tensor with packed 4-bit values.
         - [`QuantState`]: The state object used to undo the quantization.
     """
-    if A.device.type != "cuda":
-        raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
-    if quant_type not in ["fp4", "nf4"]:
-        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
-
-    n = A.numel()
     input_shape = A.shape
 
-    if absmax is None:
-        blocks = -(n // -blocksize)
-        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-
-    if out is None:
-        mod = dtype2bytes[quant_storage] * 2
-        out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
-
-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
-
-    is_on_gpu([A, out, absmax])
-
-    with _cuda_device_of(A):
-        args = (
-            None,
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_int32(blocksize),
-            ct.c_int(n),
-        )
-
-        if A.dtype == torch.bfloat16:
-            if quant_type == "fp4":
-                lib.cquantize_blockwise_bf16_fp4(*args)
-            else:
-                lib.cquantize_blockwise_bf16_nf4(*args)
-        elif A.dtype == torch.float16:
-            if quant_type == "fp4":
-                lib.cquantize_blockwise_fp16_fp4(*args)
-            else:
-                lib.cquantize_blockwise_fp16_nf4(*args)
-        elif A.dtype == torch.float32:
-            if quant_type == "fp4":
-                lib.cquantize_blockwise_fp32_fp4(*args)
-            else:
-                lib.cquantize_blockwise_fp32_nf4(*args)
-        else:
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+    _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
+        A,
+        blocksize,
+        quant_type,
+        quant_storage,
+    )
 
     code = get_4bit_type(quant_type, device=A.device)
 
     if compress_statistics:
-        offset = absmax.mean()
-        absmax -= offset
-        qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
-        del absmax
+        offset = _absmax.mean()
+        qabsmax, state2 = quantize_blockwise(_absmax - offset, blocksize=256)
+        del _absmax
         state = QuantState(
             absmax=qabsmax,
             shape=input_shape,
@@ -1262,7 +1045,7 @@ def quantize_4bit(
         )
     else:
        state = QuantState(
-            absmax=absmax,
+            absmax=_absmax,
            shape=input_shape,
            dtype=A.dtype,
            blocksize=blocksize,
@@ -1270,6 +1053,13 @@ def quantize_4bit(
            quant_type=quant_type,
        )
 
+    # TODO(matthewdouglas): Deprecate out kwarg
+    out = out.copy_(_out) if out is not None else _out
+
+    # TODO(matthewdouglas): Deprecate absmax kwarg
+    if absmax is not None:
+        state.absmax = absmax.copy_(state.absmax)
+
     return out, state
...@@ -1327,14 +1117,6 @@ def dequantize_4bit( ...@@ -1327,14 +1117,6 @@ def dequantize_4bit(
Returns: Returns:
`torch.Tensor`: The dequantized tensor. `torch.Tensor`: The dequantized tensor.
""" """
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
)
if quant_type not in ["fp4", "nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
if quant_state is None: if quant_state is None:
assert absmax is not None and out is not None assert absmax is not None and out is not None
...@@ -1355,42 +1137,19 @@ def dequantize_4bit( ...@@ -1355,42 +1137,19 @@ def dequantize_4bit(
if absmax.dtype != torch.float32: if absmax.dtype != torch.float32:
absmax = absmax.float() absmax = absmax.float()
if out is None: if out is not None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
n = out.numel()
is_on_gpu([A, absmax, out])
stream = _get_tensor_stream(A)
with _cuda_device_of(A):
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
) )
if out.dtype == torch.bfloat16:
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif out.dtype == torch.float16:
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif out.dtype == torch.float32:
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") out = torch.ops.bitsandbytes.dequantize_4bit.default(
A,
absmax,
quant_state.blocksize,
quant_state.quant_type,
quant_state.shape,
quant_state.dtype,
)
if A.shape[0] == 1: # is transposed, transpose back if A.shape[0] == 1: # is transposed, transpose back
return out.t() return out.t()
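A round-trip sketch for the 4-bit path, assuming the CUDA backend is registered; the second call exercises the pre-allocated `out=` argument, which dispatches to the op's `.out` overload. Shapes are illustrative:

import torch
import bitsandbytes.functional as F

x = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
q, state = F.quantize_4bit(x, quant_type="nf4")

# Allocation handled by the op.
x_hat = F.dequantize_4bit(q, state)

# Caller-provided buffer; writes through dequantize_4bit.out.
buf = torch.empty(state.shape, dtype=state.dtype, device=x.device)
F.dequantize_4bit(q, state, out=buf)

# Blockwise 4-bit quantization is lossy, but the reconstruction error is bounded per block.
print((x - x_hat).abs().mean())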
...@@ -1849,6 +1608,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: ...@@ -1849,6 +1608,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
return current_gnorm, clip_value, gnorm_scale return current_gnorm, clip_value, gnorm_scale
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2 assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32 assert histogram.dtype == torch.float32
...@@ -1959,101 +1719,34 @@ def gemv_4bit( ...@@ -1959,101 +1719,34 @@ def gemv_4bit(
transposed_B=False, transposed_B=False,
state=None, state=None,
): ):
# sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if state is None: if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
if A.numel() != A.shape[-1]:
raise ValueError(
'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]',
)
Bshape = state.shape
bout = Bshape[0]
absmax = state.absmax absmax = state.absmax
if state.nested: if state.nested:
absmax = dequantize_blockwise(state.absmax, state.state2) absmax = dequantize_blockwise(absmax, state.state2) + state.offset
absmax += state.offset
if out is None: if out is not None:
if len(A.shape) == 3: torch.ops.bitsandbytes.gemv_4bit.out(
out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) A,
else: B,
out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) state.shape,
absmax,
n = 1 state.code,
m = Bshape[0] state.blocksize,
k = Bshape[1] out=out,
lda = Bshape[0]
ldc = Bshape[0]
ldb = (A.shape[-1] + 1) // 2
is_on_gpu([B, A, out, absmax, state.code])
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)
stream = _get_tensor_stream(A)
with _cuda_device_of(A):
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(state.code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(state.blocksize),
stream,
) )
else:
raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
else:
raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
return out return out
return torch.ops.bitsandbytes.gemv_4bit.default(
A,
B,
state.shape,
absmax,
state.code,
state.blocksize,
)
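A sketch of the inference-time GEMV path, assuming a CUDA device; passing the transposed packed weight mirrors how `matmul_4bit` dispatches to this function, and the sizes are illustrative:

import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
qW, state = F.quantize_4bit(W, quant_type="nf4")

# A must be a vector (all leading dimensions 1), e.g. a single decode-step token.
x = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")
y = F.gemv_4bit(x, qW.t(), state=state)
print(y.shape)  # expected: (1, 1, 4096), i.e. (*x.shape[:-1], state.shape[0])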
def igemm( def igemm(
A: Tensor, A: Tensor,
...@@ -2252,27 +1945,6 @@ def batched_igemm( ...@@ -2252,27 +1945,6 @@ def batched_igemm(
return out return out
@deprecated(
"igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.",
category=FutureWarning,
)
def igemmlt(
A: torch.Tensor,
B: torch.Tensor,
SA: Tuple[torch.Size, str],
SB: Tuple[torch.Size, str],
out: Optional[torch.Tensor] = None,
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
):
if SA is not None and SA[1] != "row":
raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
if SB is not None and SB[1] != "row":
raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
result = int8_linear_matmul(A, B, out=out, dtype=dtype)
return result, (result.shape, "row")
def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
"""Performs an 8-bit integer matrix multiplication. """Performs an 8-bit integer matrix multiplication.
...@@ -2292,89 +1964,12 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten ...@@ -2292,89 +1964,12 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
Returns: Returns:
`torch.Tensor`: The result of the operation. `torch.Tensor`: The result of the operation.
""" """
#
# To use the IMMA tensor core kernels without special Turing/Ampere layouts,
# cublasLt has some rules, namely: A must be transposed, B must not be transposed.
# The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major.
# This will typically be used with row-major tensors to efficiently
# calculate the linear layer with `C = B @ A.T` without any transformations.
# We will swap A and B in the API invocation, so that we get `C = A @ B.T`.
#
# Quick explanation:
# With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`.
# To get row-major output, `C.T = (A @ B.T).T = B @ A.T`.
#
A, B = B, A
shapeA = A.shape
shapeB = B.shape
assert A.dtype == torch.int8
assert B.dtype == torch.int8
assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
assert out is None or out.dtype == dtype
shapeC = (*shapeB[:-1], shapeA[0])
k, m = shapeA
n = prod(shapeB[:-1])
lda = shapeA[-1] # Weights (outputs, inputs)
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
ldc = shapeC[-1] # Output (batch, tokens, outputs)
assert (
lda == ldb
), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
# cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
# We'll fall back to a slower fp32 calculation in this circumstance.
# Fortunately, this should not be very common.
if lda % 4 != 0:
result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
if out is not None: if out is not None:
result = out.copy_(result) torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
return result
if out is None:
out = torch.empty(shapeC, device=A.device, dtype=dtype)
is_on_gpu([A, B, out])
with _cuda_device_of(A):
ctx = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
ptrRowScale = None
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)
stream = _get_tensor_stream(A)
if dtype == torch.int32:
has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
else:
has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("int8_linear_matmul not implemented!")
if has_error:
raise RuntimeError(
f"cublasLt ran into an error!\n"
f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
f"\t{(lda, ldb, ldc)=}\n"
f"\t{(m, n, k)=}"
)
return out return out
return torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
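A sketch of the int8 kernel's contract, assuming row-major int8 tensors on a CUDA device: the result is `A @ B.T` accumulated in int32 (inner dimensions that are not a multiple of 4 fall back to a slower fp32 path). Sizes are illustrative:

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, (16, 64), dtype=torch.int8, device="cuda")  # activations (tokens, features)
B = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")  # weights (out_features, in_features)

C = F.int8_linear_matmul(A, B)
assert C.shape == (16, 32) and C.dtype == torch.int32

# Reference result computed in floating point; exact here because the magnitudes stay below 2**24.
torch.testing.assert_close(C, (A.float() @ B.float().t()).to(torch.int32))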
def int8_mm_dequant( def int8_mm_dequant(
A: torch.Tensor, A: torch.Tensor,
...@@ -2395,47 +1990,16 @@ def int8_mm_dequant( ...@@ -2395,47 +1990,16 @@ def int8_mm_dequant(
Returns: Returns:
`torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`. `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
""" """
result = torch.ops.bitsandbytes.int8_mm_dequant.default(A, row_stats, col_stats, dtype=torch.float16, bias=bias)
assert A.dtype == torch.int32 # TODO(matthewdouglas): Deprecate out kwarg
if out is not None:
if bias is not None: return out.copy_(result)
assert bias.dtype == torch.float16
if out is None:
out = torch.empty_like(A, dtype=torch.float16)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
ptrBias = get_ptr(bias)
numRows = ct.c_int32(prod(A.shape[:-1]))
numCols = ct.c_int32(A.shape[-1])
is_on_gpu([A, row_stats, col_stats, out, bias])
with _cuda_device_of(A):
lib.cdequant_mm_int32_fp16(
ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
)
return out
@deprecated("mm_dequant is deprecated. Please use int8_mm_dequant() instead.", category=FutureWarning) return result
def mm_dequant(
A: torch.Tensor,
quant_state: Optional[Tuple[torch.Size, str]], # Not used
row_stats: torch.Tensor,
col_stats: torch.Tensor,
out: Optional[torch.Tensor] = None,
new_row_stats=None, # Not used
new_col_stats=None, # Not used
bias: Optional[torch.Tensor] = None,
):
return int8_mm_dequant(A, row_stats, col_stats, out, bias)
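How the int8 pieces compose into a linear-layer forward pass, sketched without the sparse outlier decomposition and assuming fp16 inputs on a CUDA device; layer sizes are illustrative:

import torch
import bitsandbytes.functional as F

x = torch.randn(8, 512, dtype=torch.float16, device="cuda")    # activations
W = torch.randn(256, 512, dtype=torch.float16, device="cuda")  # weights (out_features, in_features)

x_q, x_stats, _ = F.int8_vectorwise_quant(x)  # row-wise absmax quantization
w_q, w_stats, _ = F.int8_vectorwise_quant(W)

acc = F.int8_linear_matmul(x_q, w_q)          # int32 accumulator, shape (8, 256)
y = F.int8_mm_dequant(acc, x_stats, w_stats)  # rescale back to fp16
print(y.shape, y.dtype)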
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_colrow_absmax( def get_colrow_absmax(
A: torch.Tensor, A: torch.Tensor,
row_stats: Optional[torch.Tensor] = None, row_stats: Optional[torch.Tensor] = None,
...@@ -2493,6 +2057,7 @@ def get_colrow_absmax( ...@@ -2493,6 +2057,7 @@ def get_colrow_absmax(
return row_stats, col_stats, outlier_mask return row_stats, col_stats, outlier_mask
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_row_absmax(A: torch.Tensor, threshold=0.0): def get_row_absmax(A: torch.Tensor, threshold=0.0):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
...@@ -2611,72 +2176,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): ...@@ -2611,72 +2176,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
@deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning)
def double_quant(
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
row_stats: Optional[torch.Tensor] = None,
out_col: Optional[torch.Tensor] = None,
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]:
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
The statistics are determined both row-wise and column-wise (transposed).
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
<Tip warning={true}>
This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead.
The difference is that this function will return a [`COOSparseTensor`] for outliers instead of a column index.
</Tip>
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
- `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor.
"""
coo_tensor = None
quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant(
A,
col_stats,
row_stats,
out_col,
out_row,
threshold=threshold,
)
if threshold > 0.0 and outlier_cols is not None:
# Build a COO tensor including all of the outlier columns.
outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
outliers = A[:, outlier_cols]
coo_tensor = COOSparseTensor(
A.shape[0],
A.shape[1],
outliers.numel(),
outlier_rows.repeat_interleave(outliers.size(1)),
outlier_cols.repeat(outliers.size(0)).int(),
outliers,
)
return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
def int8_double_quant( def int8_double_quant(
A: torch.Tensor, A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None, col_stats: Optional[torch.Tensor] = None,
...@@ -2716,23 +2215,16 @@ def int8_double_quant( ...@@ -2716,23 +2215,16 @@ def int8_double_quant(
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
""" """
# TODO: Optimize/write CUDA kernel for this? if row_stats is not None:
raise ValueError("row_stats must be None. int8_double_quant() does not support pre-allocated row_stats.")
# Use CUDA kernel for rowwise and COO tensor if col_stats is not None:
quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold) raise ValueError("col_stats must be None. int8_double_quant() does not support pre-allocated col_stats.")
# PyTorch impl for colwise
_, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
if threshold > 0.0 and outlier_mask is not None:
A = A.masked_fill(outlier_mask, 0.0)
quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8)
if out_row is not None:
quant_row = out_row.copy_(quant_row)
if out_col is not None: if out_col is not None:
quant_col = out_col.copy_(quant_col) raise ValueError("out_col must be None. int8_double_quant() does not support pre-allocated out_col.")
if out_row is not None:
raise ValueError("out_row must be None. int8_double_quant() does not support pre-allocated out_row.")
return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols return torch.ops.bitsandbytes.int8_double_quant.default(A, threshold=threshold)
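A sketch of the double-quant outputs with outlier detection, assuming an fp16 input on CUDA; the injected outlier and the 6.0 threshold are illustrative:

import torch
import bitsandbytes.functional as F

A = torch.randn(4, 128, dtype=torch.float16, device="cuda")
A[0, 3] = 20.0  # a value above the threshold marks column 3 as an outlier feature

q_row, q_col, row_stats, col_stats, outlier_cols = F.int8_double_quant(A, threshold=6.0)
print(outlier_cols)  # the column indices held back for a separate fp16 matmul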
def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor): def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
...@@ -2746,7 +2238,7 @@ def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor): ...@@ -2746,7 +2238,7 @@ def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
`torch.Tensor` with dtype `torch.float32`: The dequantized tensor. `torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
""" """
# To dequantize we divide by 127, or multiply by the reciprocal. # To dequantize we divide by 127, or multiply by the reciprocal.
return A * stats.view(-1, 1) * 7.874015718698502e-3 return torch.ops.bitsandbytes.int8_vectorwise_dequant.default(A, stats)
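The dequantization is just `x_int8 * row_stat / 127`; a quick numeric check of that formula, with illustrative values:

import torch
import bitsandbytes.functional as F

q = torch.tensor([[127, -64]], dtype=torch.int8, device="cuda")
stats = torch.tensor([2.0], device="cuda")  # one row-wise absmax scale
print(F.int8_vectorwise_dequant(q, stats))  # approx. [[2.0000, -1.0079]]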
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
...@@ -2767,94 +2259,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): ...@@ -2767,94 +2259,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
- `torch.Tensor` with dtype `torch.float32`: The quantization scales. - `torch.Tensor` with dtype `torch.float32`: The quantization scales.
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
""" """
return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold)
assert A.dtype == torch.half
is_on_gpu([A])
rows = prod(A.shape[:-1])
cols = A.shape[-1]
row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
outlier_cols = None
if threshold > 0.0:
# TODO we could improve perf of this
outliers = A.abs() >= threshold
if outliers.any():
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
with _cuda_device_of(A):
lib.cint8_vector_quant(
get_ptr(A),
get_ptr(out_row),
get_ptr(row_stats),
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
_get_tensor_stream(A),
)
# Zero out values from outlier columns across all rows.
# The kernel will handle this for outliers themselves, so we can optimize for rows=1.
if rows > 1 and outlier_cols is not None:
out_row[:, outlier_cols] = 0
return out_row, row_stats, outlier_cols
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else:
new_state = (state[0], to_order) # (shape, order)
shape = state[0]
if len(shape) == 2:
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
else:
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
is_on_gpu([A, out])
if to_order == "col32":
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == "col_turing":
if transpose:
lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == "col_ampere":
if transpose:
lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == "row":
if from_order == "col_turing":
lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
post_call(prev_device)
return out, new_state
def spmm_coo( def spmm_coo(
...@@ -3059,7 +2464,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): ...@@ -3059,7 +2464,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
@deprecated( @deprecated(
"This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.", "This function is deprecated and will be removed in a future release.",
category=FutureWarning, category=FutureWarning,
) )
def vectorwise_dequant(xq, max1, quant_type="vector"): def vectorwise_dequant(xq, max1, quant_type="vector"):
...@@ -3071,7 +2476,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"): ...@@ -3071,7 +2476,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
@deprecated( @deprecated(
"This function is deprecated and will be removed in a future release. Consider using `int8_mm_dequant` instead.", "This function is deprecated and will be removed in a future release.",
category=FutureWarning, category=FutureWarning,
) )
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
...@@ -3131,51 +2536,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): ...@@ -3131,51 +2536,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
return x.to(dtype) return x.to(dtype)
else: else:
return None return None
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3:
SB = SB.squeeze(0)
if len(SB.shape) == 2:
x *= SB.t() / 127
else:
x *= SB / 127
x *= SA[1] / 127
x += offset
return x.to(dtype)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"
out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
cols = ct.c_int32(shapeA[1])
ptrA = get_ptr(A)
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)
prev_device = pre_call(A.device)
if formatA == "col_turing":
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
post_call(prev_device)
return out
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def pipeline_test(A, batch_size):
out = torch.zeros_like(A)
lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
return out
...@@ -11,7 +11,6 @@ from torch import Tensor, device, dtype, nn ...@@ -11,7 +11,6 @@ from torch import Tensor, device, dtype, nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import bitsandbytes as bnb
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import ( from bitsandbytes.utils import (
...@@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k ...@@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format] weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
if weight_format != "row": if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device) raise ValueError(f"Only 'row' weight format is supported, got {weight_format}")
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
class Embedding8bit(nn.Embedding): class Embedding8bit(nn.Embedding):
......
...@@ -2205,333 +2205,6 @@ __global__ void kdequant_mm_int32_fp16( ...@@ -2205,333 +2205,6 @@ __global__ void kdequant_mm_int32_fp16(
} }
} }
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
{
// 0. Load data into 32*32 shared memory tiles
// 1. transpose / reorder in shared memory
// 2. store
// COL32 FORMAT:
// rows*32 tiles
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
// To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values
// As such we need:
// at least 32*4 shared memory tiles for col32; preferably 32*32
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
// for efficient loading of row major we need to load 128 elements and repeat this 32 items
// this would imply a 32x128 shared memory tile -> 4kb
// It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
// we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
// register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM
//
// to make the shared memory work with that occupancy we might need to union the block loads/stores
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
// we load 128 bytes per warp with
// 32 rows for transposes that fill col32 types
// so that we can have contiguous stores
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
// we load row after row from the base_position
// Load data row by row
int warps = blockDim.x/32;
int warp_id = threadIdx.x/32;
int warp_lane = threadIdx.x % 32;
int offset = 0;
int smem_row = 0;
// each warp loads one row of 128 bytes
for(int row = warp_id; row < TILE_ROWS; row+=warps)
{
int i = base_idx + (row*cols);
// we load up to 128 bytes/items per load
int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
// 0. Load data into 32*32 shared memory tiles
if(base_row + row < rows)
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int col_idx = warp_lane+(j*32);
if(col_idx < valid_items)
local_data[j] = A[i+col_idx];
else
local_data[j] = 0;
}
}
else
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data[j] = 0;
}
if(TRANSPOSE)
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int local_col = (32*j)+warp_lane;
//int local_row = row;
// store as 256x32
smem_data[(local_col*33) + row] = local_data[j];
}
}
else
{
// treat smem as 32x256, that is 32 rows and 256 columns
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
}
smem_row += warps;
// 1. transpose / reorder in shared memory
if(smem_row % 32 == 0)
{
smem_row = 0;
__syncthreads();
for(int subrow = warp_id; subrow < 32; subrow+=warps)
{
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
switch(FORMAT)
{
case COL32:
if(TRANSPOSE)
{
// data lies in shared memory in the following way:
// row0 [col0 col1 ... col31]
// row1 [col0 col1 ... col31]
// ...
//
// As such we read consecutive entries with 256 threads (8rows x 32 columns)
// as j increase, the row increase by a factor of 8
// We load 8 rows per subrow loop, and subrow increase by 8 per loop
// so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size outRows*32 and base_row is done in increments of 32
offset = base_row*outRows;
out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
offset = (base_col/32)*(32*rows);
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
}
}
break;
case COL_TURING:
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
//
// [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
if(TRANSPOSE)
{
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size 8*32 = 256 elements offset
// for each row offset of 8 we increaes the tile first
// after all rows are exhausted, we increase the col
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 256*outRows/8*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
offset = row_offset+col_offset;
// since we process even number of rows with each j (8) and with each subrow (8j) we can determine
// odd or even rows with the warp_id (each warp processes one row)
// the col is warp_lane (max 32 columns per row) and the row warp_id
if(warp_id % 2 == 1)
// odd
offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
else
// even
offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
out[offset] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
// set offset designates the tile offset among the 8*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 8*32=256 every 8 rows
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
// first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
// each of these has 32 values in total for 32*4 = 128 as offset if odd
// every set of 4 columns increases the total offset by 16
// each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2
// this happens every 8 rows anew (subrow % 8)
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
int subcol = warp_lane;
// add local offset (4x4 sub-tile)
if(subrow % 2 == 1)
// odd
offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
else
// even
offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
out[offset] = data;
}
}
break;
case COL_AMPERE:
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
if(TRANSPOSE)
{
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size 32*32 = 1024 elements offset
// for each row offset of 32 we increaes the tile first
// after all rows are exhausted, we increase the col
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 1024*outRows/32*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
offset = row_offset+col_offset;
// same as in the non-transpose case (see below)
// the difference is that now rows = cols
// in this case warp_id = subrow
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
// global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
out[offset + (ampere_row*32) + warp_lane] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
// set offset designates the tile offset among the 32*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 32*32=1024 every 32 rows
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
// global offset + row with 32 cols each + 32 cols per j + col_idx
out[offset + (local_row*32) + warp_lane] = data;
}
}
break;
}
}
}
}
}
}
#define DENORM 1.0f/127.0f #define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32 #define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256 #define SMEM_SIZE 8*256
...@@ -2679,69 +2352,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o ...@@ -2679,69 +2352,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
} }
} }
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
int local_colidx = idx[blockIdx.x];
if(FORMAT==COL_TURING)
{
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// columns are grouped in increments of 4, meaning that one has the following rows and columns
// rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
// cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
// each thread reads 1 element = 1 row
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
int offset_per_col_tile = ((rowsA+7)/8)*32*8;
int tile_offset_rows = (row/8)*32*8;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int offset = 0;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 8;
if(row % 2 == 1)
offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
else
// even
offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);
offset += tile_offset_rows + tile_offset_cols;
char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
else if(FORMAT == COL_AMPERE)
{
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
// we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
// within each tile.
int offset_per_col_tile = ((rowsA+31)/32)*32*32;
int tile_offset_rows = (row/32)*32*32;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 32;
// this magic is taken from the cublasLt doc (search for COL32)
int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
offset += tile_offset_cols + tile_offset_rows;
char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
}
#define WARPS 3 #define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{ {
...@@ -3376,9 +2986,6 @@ template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, ...@@ -3376,9 +2986,6 @@ template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N,
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
...@@ -3386,13 +2993,6 @@ template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max ...@@ -3386,13 +2993,6 @@ template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
......
...@@ -121,8 +121,6 @@ template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8Vector ...@@ -121,8 +121,6 @@ template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8Vector
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
......
...@@ -374,48 +374,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2) ...@@ -374,48 +374,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
} }
} }
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
cublasOperation_t opTranspose = CUBLAS_OP_T;
float transformAlpha = 1.0f, transformBeta = 0.0f;
if(DTYPE == 8)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
}
else if(DTYPE == 32)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
}
else
{
printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
}
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
}
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt( template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
cublasLtHandle_t ltHandle, cublasLtHandle_t ltHandle,
int m, int n, int k, int m, int n, int k,
...@@ -542,50 +500,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, ...@@ -542,50 +500,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
{
int threads = 256;
int items_per_thread = 8;
// we load 128 column values per warp
int tile_cols = 32*items_per_thread;
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 8);
else
outRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 32);
else
outRows = fill_up_to_nearest_multiple(rows, 32);
}
else
{
if(TRANSPOSE)
{
outCols = fill_up_to_nearest_multiple(rows, 32);
outRows = cols;
}
}
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ {
cusparseSpMatDescr_t descA; cusparseSpMatDescr_t descA;
...@@ -643,32 +557,6 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, ...@@ -643,32 +557,6 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
{
int threads = 256;
// we load 128 column values per warp
int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
int tiledRows = 0;
int num_blocks = idx_size;
if(FORMAT == COL_TURING)
{
tiledRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
tiledRows = fill_up_to_nearest_multiple(rows, 32);
}
kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{ {
...@@ -722,8 +610,6 @@ template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * ...@@ -722,8 +610,6 @@ template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float *
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
...@@ -732,13 +618,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons ...@@ -732,13 +618,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
...@@ -840,15 +719,6 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);
...@@ -173,20 +173,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
...
...@@ -149,32 +149,6 @@ void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
} \
MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
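To make the removed macro above easier to follow, here is roughly what a single invocation expands to. The expansion is a sketch for illustration only (the wrapper name is stitched together from the macro arguments shown above); it is not code that remains in the tree after this PR.

```cpp
// Approximate expansion of MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8):
// a thin non-templated wrapper around the templated transform<>() call, deleted here
// along with the transform template it forwards to.
void transform_8_row_to_col32_n(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2)
{
    transform<int8_t, ROW, COL32, false, 8>(ltHandle, A, out, dim1, dim2);
}
```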
void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
    return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
...@@ -317,22 +291,6 @@ extern "C"
int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
    return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
} \
MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
...@@ -342,24 +300,6 @@ extern "C"
    int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
}
void ctransform_row2col32(char * A, char *out, int rows, int cols)
{ transform_row2col32(A, out, rows, cols); }
void ctransform_row2col32T(char * A, char *out, int rows, int cols)
{ transform_row2col32T(A, out, rows, cols); }
void ctransform_row2turing(char * A, char *out, int rows, int cols)
{ transform_row2turing(A, out, rows, cols); }
void ctransform_row2turingT(char * A, char *out, int rows, int cols)
{ transform_row2turingT(A, out, rows, cols); }
void ctransform_row2ampere(char * A, char *out, int rows, int cols)
{ transform_row2ampere(A, out, rows, cols); }
void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
{ transform_row2ampereT(A, out, rows, cols); }
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
...@@ -369,9 +309,6 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
...
...@@ -174,7 +174,7 @@ export BNB_CUDA_VERSION=126
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-12.6
```
3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 12.6) and a different bitsandbytes library is loaded.
## Multi-backend Support (Alpha Release)[[multi-backend]]
...